diff --git a/docsrc/firebolt.common.rst b/docsrc/firebolt.common.rst index d2c022d7f6b..561ec83516d 100644 --- a/docsrc/firebolt.common.rst +++ b/docsrc/firebolt.common.rst @@ -42,7 +42,7 @@ Util --------------------------- .. automodule:: firebolt.common.util - :exclude-members: async_to_sync, cached_property, fix_url_schema, mixin_for, prune_dict + :exclude-members: cached_property, fix_url_schema, mixin_for, prune_dict :members: :undoc-members: :show-inheritance: diff --git a/setup.cfg b/setup.cfg index cc209eb844a..fee2c304831 100755 --- a/setup.cfg +++ b/setup.cfg @@ -26,13 +26,15 @@ install_requires = aiorwlock==1.1.0 appdirs>=1.4.4 appdirs-stubs>=0.1.0 + async-property>=0.2.1 cryptography>=3.4.0 httpx[http2]==0.24.0 pydantic[dotenv]>=1.8.2 python-dateutil>=2.8.2 - readerwriterlock==1.0.9 + readerwriterlock>=1.0.9 sqlparse>=0.4.2 - trio<0.22.0 + tricycle>=0.2.2 + trio>=0.22.0 python_requires = >=3.7 include_package_data = True package_dir = @@ -50,12 +52,12 @@ dev = mypy==0.910 pre-commit==2.15.0 pyfakefs>=4.5.3 - pytest==6.2.5 - pytest-asyncio==0.19.0 + pytest==7.2.0 pytest-cov==3.0.0 pytest-httpx==0.22.0 pytest-mock==3.6.1 pytest-timeout==2.1.0 + pytest-trio==0.8.0 pytest-xdist==2.5.0 trio-typing[mypy]==0.6.* types-cryptography==3.3.18 @@ -90,4 +92,4 @@ docstring-convention = google inline-quotes = " [tool:pytest] -asyncio_mode = auto +trio_mode = true diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index c7d2dc21fa2..f73fb28627a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -2,128 +2,38 @@ import logging import socket -from json import JSONDecodeError from types import TracebackType from typing import Any, Dict, List, Optional from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream -from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout +from httpx import AsyncHTTPTransport, Timeout from firebolt.async_db.cursor import Cursor +from firebolt.async_db.util import ( + _get_engine_url_status_db, + _get_system_engine_url, +) from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Auth, _get_auth +from firebolt.client.auth import Auth from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, KEEPIDLE_RATE, ) -from firebolt.common.util import validate_engine_name_and_url from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, - FireboltEngineError, + EngineNotRunningError, InterfaceError, ) -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_ID_BY_NAME_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, -) from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema -AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials - directly to the `connect` function is deprecated. - Pass the `Auth` object instead. - Examples: - >>> from firebolt.client.auth import UsernamePassword - >>> ... - >>> connect(auth=UsernamePassword(username, password), ...) - or - >>> from firebolt.client.auth import Token - >>> ... - >>> connect(auth=Token(access_token), ...)""" - logger = logging.getLogger(__name__) -async def _resolve_engine_url( - engine_name: str, - auth: Auth, - api_endpoint: str, - account_name: Optional[str] = None, -) -> str: - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - account_id = await client.account_id - url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - try: - response = await client.get( - url=url, - params={"engine_name": engine_name}, - ) - response.raise_for_status() - engine_id = response.json()["engine_id"]["engine_id"] - url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id) - response = await client.get(url=url) - response.raise_for_status() - return response.json()["engine"]["endpoint"] - except HTTPStatusError as e: - # Engine error would be 404. - if e.response.status_code != 404: - raise InterfaceError( - f"Error {e.__class__.__name__}: Unable to retrieve engine " - f"endpoint {url}." - ) - # Once this is point is reached we've already authenticated with - # the backend so it's safe to assume the cause of the error is - # missing engine. - raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError) as e: - raise InterfaceError( - f"Error {e.__class__.__name__}: " - f"Unable to retrieve engine endpoint {url}." - ) - - -async def _get_database_default_engine_url( - database: str, - auth: Auth, - api_endpoint: str, - account_name: Optional[str] = None, -) -> str: - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - try: - account_id = await client.account_id - response = await client.get( - url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), - params={"database_name": database}, - ) - response.raise_for_status() - return response.json()["engine_url"] - except ( - JSONDecodeError, - RequestError, - RuntimeError, - HTTPStatusError, - KeyError, - ) as e: - raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") - - class OverriddenHttpBackend(AutoBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -189,7 +99,6 @@ class Connection(BaseConnection): """ - client_class: type __slots__ = ( "_client", "_cursors", @@ -197,16 +106,20 @@ class Connection(BaseConnection): "engine_url", "api_endpoint", "_is_closed", + "_system_engine_connection", ) def __init__( self, engine_url: str, - database: str, + database: Optional[str], auth: Auth, - api_endpoint: str = DEFAULT_API_URL, + account_name: str, + system_engine_connection: Optional["Connection"], + api_endpoint: str, additional_parameters: Dict[str, Any] = {}, ): + super().__init__() self.api_endpoint = api_endpoint self.engine_url = engine_url self.database = database @@ -217,6 +130,7 @@ def __init__( user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) self._client = AsyncClient( + account_name=account_name, auth=auth, base_url=engine_url, api_endpoint=api_endpoint, @@ -224,7 +138,7 @@ def __init__( transport=transport, headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) - super().__init__() + self._system_engine_connection = system_engine_connection def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -256,6 +170,9 @@ async def aclose(self) -> None: await self._client.aclose() self._is_closed = True + if self._system_engine_connection: + await self._system_engine_connection.aclose() + async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: @@ -263,89 +180,84 @@ async def __aexit__( async def connect( - database: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, + auth: Optional[Auth] = None, account_name: Optional[str] = None, + database: Optional[str] = None, + engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: - """Connect to Firebolt database. + """Connect to Firebolt. Args: + `auth` (Auth) Authentication object `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection - - Note: - Providing both `engine_name` and `engine_url` will result in an error - """ # These parameters are optional in function signature # but are required to connect. # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") + for name, value in (("auth", auth), ("account_name", account_name)): + if not value: + raise ConfigurationError(f"{name} is required to connect.") - validate_engine_name_and_url(engine_name, engine_url) + # Type checks + assert auth is not None + assert account_name is not None - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") api_endpoint = fix_url_schema(api_endpoint) - # Mypy checks, this should never happen - assert database is not None + system_engine_url = fix_url_schema( + await _get_system_engine_url(auth, account_name, api_endpoint) + ) + # Don't use context manager since this will be stored + # and used in a resulting connection + system_engine_connection = Connection( + system_engine_url, + database, + auth, + account_name, + None, + api_endpoint, + additional_parameters, + ) - if not engine_name and not engine_url: - engine_url = await _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) + if not engine_name: + return system_engine_connection - elif engine_name: - engine_url = await _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - await client.account_id + else: + try: + engine_url, status, attached_db = await _get_engine_url_status_db( + system_engine_connection, engine_name + ) - assert engine_url is not None + if status != "Running": + raise EngineNotRunningError(engine_name) - engine_url = fix_url_schema(engine_url) - return Connection(engine_url, database, auth, api_endpoint, additional_parameters) + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + account_name, + system_engine_connection, + api_endpoint, + additional_parameters, + ) + except: # noqa + await system_engine_connection.aclose() + raise diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 890e71674a7..850039ac451 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -8,7 +8,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Iterator, List, Optional, @@ -17,10 +16,11 @@ Union, ) -from aiorwlock import RWLock from httpx import Response, codes +from tricycle import RWLock from firebolt.async_db.util import is_db_available, is_engine_running +from firebolt.client import AsyncClient from firebolt.common._types import ( ColType, Column, @@ -30,59 +30,29 @@ split_format_sql, ) from firebolt.common.base_cursor import ( + JSON_OUTPUT_FORMAT, BaseCursor, CursorState, QueryStatus, Statistics, + check_not_closed, + check_query_executed, ) from firebolt.utils.exception import ( AsyncExecutionUnavailableError, - CursorClosedError, EngineNotRunningError, FireboltDatabaseError, OperationalError, ProgrammingError, - QueryNotRunError, ) if TYPE_CHECKING: from firebolt.async_db.connection import Connection -from httpx import AsyncClient as AsyncHttpxClient logger = logging.getLogger(__name__) -JSON_OUTPUT_FORMAT = "JSON_Compact" - - -def check_not_closed(func: Callable) -> Callable: - """(Decorator) ensure cursor is not closed before calling method.""" - - @wraps(func) - def inner(self: Cursor, *args: Any, **kwargs: Any) -> Any: - if self.closed: - raise CursorClosedError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - -def check_query_executed(func: Callable) -> Callable: - """ - (Decorator) ensure that some query has been executed before - calling cursor method. - """ - - @wraps(func) - def inner(self: Cursor, *args: Any, **kwargs: Any) -> Any: - if self._state == CursorState.NONE: - raise QueryNotRunError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - class Cursor(BaseCursor): """ Executes async queries to Firebolt Database. @@ -103,7 +73,7 @@ class Cursor(BaseCursor): def __init__( self, *args: Any, - client: AsyncHttpxClient, + client: AsyncClient, connection: Connection, **kwargs: Any, ) -> None: @@ -119,7 +89,9 @@ async def _raise_if_error(self, resp: Response) -> None: f"Error executing query:\n{resp.read().decode('utf-8')}" ) if resp.status_code == codes.FORBIDDEN: - if not await is_db_available(self.connection, self.connection.database): + if self.connection.database and not await is_db_available( + self.connection, self.connection.database + ): raise FireboltDatabaseError( f"Database {self.connection.database} does not exist" ) @@ -137,10 +109,10 @@ async def _raise_if_error(self, resp: Response) -> None: async def _api_request( self, - query: Optional[str] = "", - parameters: Optional[dict[str, Any]] = {}, - path: Optional[str] = "", - use_set_parameters: Optional[bool] = True, + query: str = "", + parameters: dict[str, Any] = {}, + path: str = "", + use_set_parameters: bool = True, ) -> Response: """ Query API, return Response object. @@ -159,13 +131,14 @@ async def _api_request( """ if use_set_parameters: parameters = {**(self._set_parameters or {}), **(parameters or {})} + if self.connection.database: + parameters["database"] = self.connection.database + if self.connection._is_system: + parameters["account_id"] = await self._client.account_id return await self._client.request( - url=f"/{path}", + url=f"/{path}" if path else "", method="POST", - params={ - "database": self.connection.database, - **(parameters or dict()), - }, + params=parameters, content=query, ) @@ -350,24 +323,6 @@ async def executemany( def __aiter__(self) -> Cursor: return self - def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - self.connection._remove_cursor(self) - - def __del__(self) -> None: - self.close() - - # Context manager support - @check_not_closed - def __enter__(self) -> Cursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - # TODO: figure out how to implement __aenter__ and __await__ @check_not_closed def __aenter__(self) -> Cursor: @@ -429,13 +384,13 @@ async def cancel(self, query_id: str) -> None: @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): """Fetch the next row of a query result set.""" return super().fetchone() @wraps(BaseCursor.fetchmany) async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): """ Fetch the next set of rows of a query result; size is cursor.arraysize by default. @@ -444,11 +399,15 @@ async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: @wraps(BaseCursor.fetchall) async def fetchall(self) -> List[List[ColType]]: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): """Fetch all remaining rows of a query result.""" return super().fetchall() @wraps(BaseCursor.nextset) async def nextset(self) -> None: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): return super().nextset() + + @check_not_closed + def __enter__(self) -> Cursor: + return self diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index ab560c75379..180c3efa6a6 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -1,14 +1,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple -from httpx import URL, Response +from httpx import URL, Timeout, codes -from firebolt.utils.urls import DATABASES_URL, ENGINES_URL +from firebolt.client import AsyncClient +from firebolt.client.auth import Auth +from firebolt.common.settings import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.exception import ( + AccountNotFoundError, + FireboltEngineError, + InterfaceError, +) +from firebolt.utils.urls import DYNAMIC_QUERY, GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: from firebolt.async_db.connection import Connection +ENGINE_STATUS_RUNNING = "Running" + async def is_db_available(connection: Connection, database_name: str) -> bool: """ @@ -16,11 +26,19 @@ async def is_db_available(connection: Connection, database_name: str) -> bool: Args: connection (firebolt.async_db.connection.Connection) + database_name (str): Name of a database """ - resp = await _filter_request( - connection, DATABASES_URL, {"filter.name_contains": database_name} - ) - return len(resp.json()["edges"]) > 0 + system_engine = connection._system_engine_connection or connection + with system_engine.cursor() as cursor: + return ( + await cursor.execute( + """ + SELECT 1 FROM information_schema.databases WHERE database_name=? + """, + [database_name], + ) + > 0 + ) async def is_engine_running(connection: Connection, engine_url: str) -> bool: @@ -29,29 +47,58 @@ async def is_engine_running(connection: Connection, engine_url: str) -> bool: Args: connection (firebolt.async_db.connection.Connection): connection. + engine_url (str): URL of the engine """ - # Url is not guaranteed to be of this structure, - # but for the sake of error checking this is sufficient. + + if connection._is_system: + # System engine is always running + return True + engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") - resp = await _filter_request( - connection, - ENGINES_URL, - { - "filter.name_contains": engine_name, - "filter.current_status_eq": "ENGINE_STATUS_RUNNING_REVISION_SERVING", - }, + assert connection._system_engine_connection is not None # Type check + _, status, _ = await _get_engine_url_status_db( + connection._system_engine_connection, engine_name ) - return len(resp.json()["edges"]) > 0 + return status == ENGINE_STATUS_RUNNING -async def _filter_request( - connection: Connection, endpoint: str, filters: dict -) -> Response: - resp = await connection._client.request( - # Full url overrides the client url, which contains engine as a prefix. - url=connection.api_endpoint + endpoint, - method="GET", - params=filters, - ) - resp.raise_for_status() - return resp +async def _get_system_engine_url( + auth: Auth, + account_name: str, + api_endpoint: str, +) -> str: + async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = await client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content}" + ) + return response.json()["engineUrl"] + DYNAMIC_QUERY + + +async def _get_engine_url_status_db( + system_engine: Connection, engine_name: str +) -> Tuple[str, str, str]: + with system_engine.cursor() as cursor: + await cursor.execute( + """ + SELECT url, attached_to, status FROM information_schema.engines + WHERE engine_name=? + """, + [engine_name], + ) + row = await cursor.fetchone() + if row is None: + raise FireboltEngineError(f"Engine with name {engine_name} doesn't exist") + engine_url, database, status = row + return str(engine_url), str(status), str(database) # Mypy check diff --git a/src/firebolt/client/auth/__init__.py b/src/firebolt/client/auth/__init__.py index 90e4d4149a5..51c9bcdc95d 100644 --- a/src/firebolt/client/auth/__init__.py +++ b/src/firebolt/client/auth/__init__.py @@ -1,5 +1,2 @@ from firebolt.client.auth.base import Auth -from firebolt.client.auth.service_account import ServiceAccount -from firebolt.client.auth.token import Token -from firebolt.client.auth.username_password import UsernamePassword -from firebolt.client.auth.utils import _get_auth +from firebolt.client.auth.client_credentials import ClientCredentials diff --git a/src/firebolt/client/auth/service_account.py b/src/firebolt/client/auth/client_credentials.py similarity index 88% rename from src/firebolt/client/auth/service_account.py rename to src/firebolt/client/auth/client_credentials.py index a9181b82e72..6370a2b49f2 100644 --- a/src/firebolt/client/auth/service_account.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -7,7 +7,7 @@ from firebolt.utils.util import cached_property -class ServiceAccount(_RequestBasedAuth): +class ClientCredentials(_RequestBasedAuth): """Service Account authentication class for Firebolt Database. Gets authentication token using @@ -45,13 +45,15 @@ def __init__( self.client_secret = client_secret super().__init__(use_token_cache) - def copy(self) -> "ServiceAccount": + def copy(self) -> "ClientCredentials": """Make another auth object with same credentials. Returns: - ServiceAccount: Auth object + ClientCredentials: Auth object """ - return ServiceAccount(self.client_id, self.client_secret, self._use_token_cache) + return ClientCredentials( + self.client_id, self.client_secret, self._use_token_cache + ) @cached_property def _token_storage(self) -> Optional[TokenSecureStorage]: @@ -84,6 +86,7 @@ def _make_auth_request(self) -> AuthRequest: "client_id": self.client_id, "client_secret": self.client_secret, "grant_type": "client_credentials", + "audience": "https://api.firebolt.io", }, ) return response diff --git a/src/firebolt/client/auth/token.py b/src/firebolt/client/auth/token.py deleted file mode 100644 index 01864a1cc09..00000000000 --- a/src/firebolt/client/auth/token.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Generator - -from httpx import Request, Response - -from firebolt.client.auth import Auth -from firebolt.utils.exception import AuthorizationError - - -class Token(Auth): - """Token authentication class for Firebolt Database. - - Uses provided token for authentication. Doesn't cache token and doesn't - refresh it on expiration. - - Args: - token (str): Authorization token - - Attributes: - token (str): - """ - - def __init__(self, token: str): - super().__init__(use_token_cache=False) - self._token = token - - def copy(self) -> "Token": - """Make another auth object with same credentials. - - Returns: - Token: Auth object - """ - assert self.token - return Token(self.token) - - def get_new_token_generator(self) -> Generator[Request, Response, None]: - """Raise authorization error since token is invalid or expired. - - Raises: - AuthorizationError: Token is invalid or expired - """ - raise AuthorizationError("Provided token in not valid anymore.") diff --git a/src/firebolt/client/auth/username_password.py b/src/firebolt/client/auth/username_password.py deleted file mode 100644 index 21fbd31936d..00000000000 --- a/src/firebolt/client/auth/username_password.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Optional - -from firebolt.client.auth.base import AuthRequest -from firebolt.client.auth.request_auth_base import _RequestBasedAuth -from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import AUTH_URL -from firebolt.utils.util import cached_property - - -class UsernamePassword(_RequestBasedAuth): - """Username/Password authentication class for Firebolt Database. - - Gets authentication token using - provided credentials and updates it when it expires. - - Args: - username (str): Username - password (str): Password - use_token_cache (bool): True if token should be cached in filesystem; - False otherwise - - Attributes: - username (str): Username - password (str): Password - """ - - __slots__ = ( - "username", - "password", - "_token", - "_expires", - "_use_token_cache", - "_user_agent", - ) - - requires_response_body = True - - def __init__( - self, - username: str, - password: str, - use_token_cache: bool = True, - ): - self.username = username - self.password = password - super().__init__(use_token_cache) - - def copy(self) -> "UsernamePassword": - """Make another auth object with same credentials. - - Returns: - UsernamePassword: Auth object - """ - return UsernamePassword(self.username, self.password, self._use_token_cache) - - @cached_property - def _token_storage(self) -> Optional[TokenSecureStorage]: - """Token filesystem cache storage. - - This is evaluated lazily, only if caching is enabled - - Returns: - TokenSecureStorage: Token filesystem cache storage - """ - return TokenSecureStorage(username=self.username, password=self.password) - - def _make_auth_request(self) -> AuthRequest: - """Get new token using username and password. - - Yields: - Request: An http request to get token. Expects Response to be sent back - - Raises: - AuthenticationError: Error while authenticating with provided credentials - """ - response = self.request_class( - "POST", - AUTH_URL, - headers={ - "Content-Type": "application/json;charset=UTF-8", - "User-Agent": self._user_agent, - }, - json={"username": self.username, "password": self.password}, - ) - return response diff --git a/src/firebolt/client/auth/utils.py b/src/firebolt/client/auth/utils.py deleted file mode 100644 index f6bb6b09374..00000000000 --- a/src/firebolt/client/auth/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -from firebolt.client.auth import Auth, Token, UsernamePassword -from firebolt.utils.exception import ConfigurationError - - -def _get_auth( - username: Optional[str], - password: Optional[str], - access_token: Optional[str], - use_token_cache: bool, -) -> Auth: - """Create `Auth` class based on provided credentials. - - If `access_token` is provided, it's used for `Auth` creation. - Otherwise, username/password are used. - - Returns: - Auth: `auth object` - - Raises: - `ConfigurationError`: Invalid combination of credentials provided - - """ - if not access_token: - if not username or not password: - raise ConfigurationError( - "Neither username/password nor access_token are provided. Provide one" - " to authenticate." - ) - return UsernamePassword(username, password, use_token_cache) - if username or password: - raise ConfigurationError( - "Username/password and access_token are both provided. Provide only one" - " to authenticate." - ) - return Token(access_token) diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index f6e378a595f..c341e82cc01 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -1,6 +1,7 @@ from typing import Any, Optional from anyio._core._eventloop import get_asynclib +from async_property import async_cached_property # type: ignore from httpx import URL from httpx import AsyncClient as HttpxAsyncClient from httpx import Client as HttpxClient @@ -12,10 +13,11 @@ from firebolt.client.auth.base import AuthRequest from firebolt.client.constants import DEFAULT_API_URL from firebolt.utils.exception import AccountNotFoundError -from firebolt.utils.urls import ACCOUNT_BY_NAME_URL, ACCOUNT_URL +from firebolt.utils.urls import ACCOUNT_BY_NAME_URL from firebolt.utils.util import ( cached_property, fix_url_schema, + get_auth_endpoint, merge_urls, mixin_for, ) @@ -38,16 +40,17 @@ class FireboltClientMixin(FireboltClientMixinBase): def __init__( self, *args: Any, - account_name: Optional[str] = None, + account_name: str, + auth: Auth, api_endpoint: str = DEFAULT_API_URL, - auth: Optional[Auth] = None, **kwargs: Any, ): self.account_name = account_name self._api_endpoint = URL(fix_url_schema(api_endpoint)) + self._auth_endpoint = get_auth_endpoint(self._api_endpoint) super().__init__(*args, auth=auth, **kwargs) - def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: + def _build_auth(self, auth: Optional[AuthTypes]) -> Auth: """Create Auth object based on auth provided. Overrides ``httpx.Client._build_auth`` @@ -61,16 +64,21 @@ def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: Raises: TypeError: Auth argument has unsupported type """ - if auth is None or isinstance(auth, Auth): - return auth - raise TypeError(f'Invalid "auth" argument: {auth!r}') + if not (auth is None or isinstance(auth, Auth)): + raise TypeError(f'Invalid "auth" argument: {auth!r}') + assert auth is not None # type check + return auth def _merge_auth_request(self, request: Request) -> Request: if isinstance(request, AuthRequest): - request.url = merge_urls(self._api_endpoint, request.url) + request.url = merge_urls(self._auth_endpoint, request.url) request._prepare(dict(request.headers)) return request + def _enforce_trailing_slash(self, url: URL) -> URL: + """Don't automatically append trailing slach to a base url""" + return url + class Client(FireboltClientMixin, HttpxClient): """An HTTP client, based on httpx.Client. @@ -93,18 +101,16 @@ def account_id(self) -> str: Raises: AccountNotFoundError: No account found with provided name """ - if self.account_name: - response = self.get( - url=ACCOUNT_BY_NAME_URL, params={"account_name": self.account_name} + response = self.get( + url=self._api_endpoint.copy_with( + path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name) ) - if response.status_code == HttpxCodes.NOT_FOUND: - raise AccountNotFoundError(self.account_name) - # process all other status codes - response.raise_for_status() - return response.json()["account_id"] - - # account_name isn't set, use the default account. - return self.get(url=ACCOUNT_URL).json()["account"]["id"] + ) + if response.status_code == HttpxCodes.NOT_FOUND: + raise AccountNotFoundError(self.account_name) + # process all other status codes + response.raise_for_status() + return response.json()["id"] def _send_handling_redirects( self, request: Request, *args: Any, **kwargs: Any @@ -122,7 +128,7 @@ class AsyncClient(FireboltClientMixin, HttpxAsyncClient): FireboltAuth instance. """ - @cached_property + @async_cached_property async def account_id(self) -> str: """User account id. @@ -135,18 +141,16 @@ async def account_id(self) -> str: Raises: AccountNotFoundError: No account found with provided name """ - if self.account_name: - response = await self.get( - url=ACCOUNT_BY_NAME_URL, params={"account_name": self.account_name} + response = await self.get( + url=self._api_endpoint.copy_with( + path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name) ) - if response.status_code == HttpxCodes.NOT_FOUND: - raise AccountNotFoundError(self.account_name) - # process all other status codes - response.raise_for_status() - return response.json()["account_id"] - - # account_name isn't set; use the default account. - return (await self.get(url=ACCOUNT_URL)).json()["account"]["id"] + ) + if response.status_code == HttpxCodes.NOT_FOUND: + raise AccountNotFoundError(self.account_name) + # process all other status codes + response.raise_for_status() + return response.json()["id"] async def _send_handling_redirects( self, request: Request, *args: Any, **kwargs: Any diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py index 1d47374015c..f66993e52e1 100644 --- a/src/firebolt/common/base_connection.py +++ b/src/firebolt/common/base_connection.py @@ -1,5 +1,6 @@ from typing import Any, List +from firebolt.common.base_cursor import BaseCursor from firebolt.utils.exception import ConnectionClosedError @@ -8,13 +9,18 @@ def __init__(self) -> None: self._cursors: List[Any] = [] self._is_closed = False - def _remove_cursor(self, cursor: Any) -> None: + def _remove_cursor(self, cursor: BaseCursor) -> None: # This way it's atomic try: self._cursors.remove(cursor) except ValueError: pass + @property + def _is_system(self) -> bool: + """`True` if connection is a system engine connection; `False` otherwise.""" + return self._system_engine_connection is None # type: ignore + @property def closed(self) -> bool: """`True` if connection is closed; `False` otherwise.""" diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index eb288998116..b0925bfc001 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -3,6 +3,7 @@ import logging from enum import Enum from functools import wraps +from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from httpx import Response @@ -363,3 +364,16 @@ def setinputsizes(self, sizes: List[int]) -> None: @check_not_closed def setoutputsize(self, size: int, column: Optional[int] = None) -> None: """Set a column buffer size for fetches of large columns (does nothing).""" + + def close(self) -> None: + """Terminate an ongoing query (if any) and mark connection as closed.""" + self._state = CursorState.CLOSED + self.connection._remove_cursor(self) # type:ignore + + def __del__(self) -> None: + self.close() + + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index 663b17a7fc4..2c0c3ead742 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -1,9 +1,9 @@ import logging import os from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any, Callable -from firebolt.client.auth import Auth, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials logger = logging.getLogger(__name__) @@ -12,20 +12,8 @@ KEEPIDLE_RATE: int = 60 # seconds DEFAULT_TIMEOUT_SECONDS: int = 60 -AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials directly in Settings is deprecated. - Use Auth object instead. - Examples: - >>> from firebolt.client.auth import UsernamePassword - >>> ... - >>> settings = Settings(auth=UsernamePassword(username, password), ...) - or - >>> from firebolt.client.auth import Token - >>> ... - >>> settings = Settings(auth=Token(access_token), ...)""" - -USERNAME_ENV = "FIREBOLT_USER" -PASSWORD_ENV = "FIREBOLT_PASSWORD" -AUTH_TOKEN_ENV = "FIREBOLT_AUTH_TOKEN" +CLIENT_ID_ENV = "FIREBOLT_CLIENT_ID" +CLIENT_SECRET_ENV = "FIREBOLT_CLIENT_SECRET" ACCOUNT_ENV = "FIREBOLT_ACCOUNT" SERVER_ENV = "FIREBOLT_SERVER" DEFAULT_REGION_ENV = "FIREBOLT_DEFAULT_REGION" @@ -38,12 +26,12 @@ def inner() -> Any: return inner -def auth_from_env() -> Optional[Auth]: - username = os.environ.get(USERNAME_ENV, None) - password = os.environ.get(PASSWORD_ENV, None) - if username and password: - return UsernamePassword(username, password) - return None +def auth_from_env() -> Auth: + client_id = os.environ.get(CLIENT_ID_ENV, None) + client_secret = os.environ.get(CLIENT_SECRET_ENV, None) + if client_id and client_secret: + return ClientCredentials(client_id, client_secret) + raise ValueError("Auth not provided") @dataclass @@ -62,41 +50,8 @@ class Settings: default_region (str): Default region for provisioning """ - auth: Optional[Auth] = field(default_factory=auth_from_env) - # Authorization - user: Optional[str] = field(default=None) - password: Optional[str] = field(default=None) - # Or - access_token: Optional[str] = field(default_factory=from_env(AUTH_TOKEN_ENV)) + auth: Auth = field(default_factory=auth_from_env) - account_name: Optional[str] = field(default_factory=from_env(ACCOUNT_ENV)) + account_name: str = field(default_factory=from_env(ACCOUNT_ENV)) server: str = field(default_factory=from_env(SERVER_ENV)) default_region: str = field(default_factory=from_env(DEFAULT_REGION_ENV)) - use_token_cache: bool = field(default=True) - - def __post_init__(self) -> None: - """Validate that either creds or token is provided. - - Args: - values (dict): settings initial values - - Returns: - dict: Validated settings values - - Raises: - ValueError: Either both or none of credentials and token are provided - """ - - params_present = ( - self.user is not None or self.password is not None, - self.access_token is not None, - self.auth is not None, - ) - if sum(params_present) == 0: - raise ValueError( - "Provide at least one of auth, user/password or access_token." - ) - if sum(params_present) > 1: - raise ValueError("Provide only one of auth, user/password or access_token") - if any((self.user, self.password, self.access_token)): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 3a29d7852e5..adb8bf46e71 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -2,119 +2,37 @@ import logging import socket -from json import JSONDecodeError from types import TracebackType from typing import Any, Dict, List, Optional from warnings import warn from httpcore.backends.base import NetworkStream from httpcore.backends.sync import SyncBackend -from httpx import HTTPStatusError, HTTPTransport, RequestError, Timeout +from httpx import HTTPTransport, Timeout from readerwriterlock.rwlock import RWLockWrite from firebolt.client import DEFAULT_API_URL, Client -from firebolt.client.auth import Auth, _get_auth +from firebolt.client.auth import Auth from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( - AUTH_CREDENTIALS_DEPRECATION_MESSAGE, DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, KEEPIDLE_RATE, ) -from firebolt.common.util import validate_engine_name_and_url from firebolt.db.cursor import Cursor +from firebolt.db.util import _get_engine_url_status_db, _get_system_engine_url from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, - FireboltEngineError, + EngineNotRunningError, InterfaceError, ) -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_ID_BY_NAME_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, -) from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema logger = logging.getLogger(__name__) -def _resolve_engine_url( - engine_name: str, - auth: Auth, - api_endpoint: str, - account_name: Optional[str] = None, -) -> str: - with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - account_id = client.account_id - url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - try: - response = client.get( - url=url, - params={"engine_name": engine_name}, - ) - response.raise_for_status() - engine_id = response.json()["engine_id"]["engine_id"] - url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id) - response = client.get(url=url) - response.raise_for_status() - return response.json()["engine"]["endpoint"] - except HTTPStatusError as e: - # Engine error would be 404. - if e.response.status_code != 404: - raise InterfaceError( - f"Error {e.__class__.__name__}: Unable to retrieve engine " - f"endpoint {url}." - ) - # Once this is point is reached we've already authenticated with - # the backend so it's safe to assume the cause of the error is - # missing engine. - raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError) as e: - raise InterfaceError( - f"Error {e.__class__.__name__}: " - f"Unable to retrieve engine endpoint {url}." - ) - - -def _get_database_default_engine_url( - database: str, - auth: Auth, - api_endpoint: str, - account_name: Optional[str] = None, -) -> str: - with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - try: - account_id = client.account_id - response = client.get( - url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), - params={"database_name": database}, - ) - response.raise_for_status() - return response.json()["engine_url"] - except ( - JSONDecodeError, - RequestError, - RuntimeError, - HTTPStatusError, - KeyError, - ) as e: - raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") - - class OverriddenHttpBackend(SyncBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -174,7 +92,6 @@ class Connection(BaseConnection): are not implemented. """ - client_class: type __slots__ = ( "_client", "_cursors", @@ -183,16 +100,20 @@ class Connection(BaseConnection): "api_endpoint", "_is_closed", "_closing_lock", + "_system_engine_connection", ) def __init__( self, engine_url: str, - database: str, + database: Optional[str], auth: Auth, + account_name: str, + system_engine_connection: Optional["Connection"], api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ): + super().__init__() self.api_endpoint = api_endpoint self.engine_url = engine_url self.database = database @@ -203,6 +124,7 @@ def __init__( user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) self._client = Client( + account_name=account_name, auth=auth, base_url=engine_url, api_endpoint=api_endpoint, @@ -210,10 +132,10 @@ def __init__( transport=transport, headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) + self._system_engine_connection = system_engine_connection # Holding this lock for write means that connection is closing itself. # cursor() should hold this lock for read to read/write state self._closing_lock = RWLockWrite() - super().__init__() def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -224,13 +146,6 @@ def cursor(self, **kwargs: Any) -> Cursor: self._cursors.append(c) return c - def _remove_cursor(self, cursor: Cursor) -> None: - # This way it's atomic - try: - self._cursors.remove(cursor) - except ValueError: - pass - def close(self) -> None: if self.closed: return @@ -247,6 +162,9 @@ def close(self) -> None: self._client.close() self._is_closed = True + if self._system_engine_connection: + self._system_engine_connection.close() + # Context manager support def __enter__(self) -> Connection: if self.closed: @@ -260,93 +178,89 @@ def __exit__( def __del__(self) -> None: if not self.closed: - warn(f"Unclosed {self!r}", UserWarning) + warn(f"Unclosed {self!r} {id(self)}", UserWarning) def connect( - database: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, + auth: Optional[Auth] = None, account_name: Optional[str] = None, + database: Optional[str] = None, + engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: - """Connect to Firebolt database. + """Connect to Firebolt. Args: + `auth` (Auth) Authentication object `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use `account_name` (Optional[str]): For customers with multiple accounts; if none, default is used `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection - Note: - Providing both `engine_name` and `engine_url` will result in an error - """ # These parameters are optional in function signature # but are required to connect. # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") + for name, value in (("auth", auth), ("account_name", account_name)): + if not value: + raise ConfigurationError(f"{name} is required to connect.") - validate_engine_name_and_url(engine_name, engine_url) + # Type checks + assert auth is not None + assert account_name is not None - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") api_endpoint = fix_url_schema(api_endpoint) - # Mypy checks, this should never happen - assert database is not None - - if not engine_name and not engine_url: - engine_url = _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) + system_engine_url = fix_url_schema( + _get_system_engine_url(auth, account_name, api_endpoint) + ) + # Don't use context manager since this will be stored + # and used in a resulting connection + system_engine_connection = Connection( + system_engine_url, + database, + auth, + account_name, + None, + api_endpoint, + additional_parameters, + ) + if not engine_name: + return system_engine_connection - elif engine_name: - engine_url = _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. - with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - client.account_id + else: + try: + engine_url, status, attached_db = _get_engine_url_status_db( + system_engine_connection, engine_name + ) - assert engine_url is not None + if status != "Running": + raise EngineNotRunningError(engine_name) - engine_url = fix_url_schema(engine_url) - return Connection(engine_url, database, auth, api_endpoint, additional_parameters) + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + account_name, + system_engine_connection, + api_endpoint, + additional_parameters, + ) + except: # noqa + system_engine_connection.close() + raise diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 69c84a81115..c516d1e8e25 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -4,7 +4,6 @@ import re import time from threading import Lock -from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -16,10 +15,10 @@ Union, ) -from httpx import Client as HttpxClient from httpx import Response, codes from readerwriterlock.rwlock import RWLockWrite +from firebolt.client import Client from firebolt.common._types import ( ColType, Column, @@ -72,7 +71,7 @@ class Cursor(BaseCursor): ) def __init__( - self, *args: Any, client: HttpxClient, connection: Connection, **kwargs: Any + self, *args: Any, client: Client, connection: Connection, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self._query_lock = RWLockWrite() @@ -87,7 +86,9 @@ def _raise_if_error(self, resp: Response) -> None: f"Error executing query:\n{resp.read().decode('utf-8')}" ) if resp.status_code == codes.FORBIDDEN: - if not is_db_available(self.connection, self.connection.database): + if self.connection.database and not is_db_available( + self.connection, self.connection.database + ): raise FireboltDatabaseError( f"Database {self.connection.database} does not exist" ) @@ -105,10 +106,10 @@ def _raise_if_error(self, resp: Response) -> None: def _api_request( self, - query: Optional[str] = "", - parameters: Optional[dict[str, Any]] = {}, - path: Optional[str] = "", - use_set_parameters: Optional[bool] = True, + query: str = "", + parameters: dict[str, Any] = {}, + path: str = "", + use_set_parameters: bool = True, ) -> Response: """ Query API, return Response object. @@ -127,13 +128,14 @@ def _api_request( """ if use_set_parameters: parameters = {**(self._set_parameters or {}), **(parameters or {})} + if self.connection.database: + parameters["database"] = self.connection.database + if self.connection._is_system: + parameters["account_id"] = self._client.account_id return self._client.request( - url=f"/{path}", + url=f"/{path}" if path else "", method="POST", - params={ - "database": self.connection.database, - **(parameters or dict()), - }, + params=parameters, content=query, ) @@ -341,24 +343,6 @@ def get_status(self, query_id: str) -> QueryStatus: return QueryStatus.NOT_READY return QueryStatus[resp_json["status"]] - def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - self.connection._remove_cursor(self) - - def __del__(self) -> None: - self.close() - - # Context manager support - @check_not_closed - def __enter__(self) -> Cursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - @check_not_closed def cancel(self, query_id: str) -> None: """Cancel a server-side async query.""" @@ -377,3 +361,7 @@ def __iter__(self) -> Generator[List[ColType], None, None]: if row is None: return yield row + + @check_not_closed + def __enter__(self) -> Cursor: + return self diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 94f24c3be7c..44a3a71de47 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -1,26 +1,44 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple -from httpx import URL, Response +from httpx import URL, Timeout, codes -from firebolt.utils.urls import DATABASES_URL, ENGINES_URL +from firebolt.client import Client +from firebolt.client.auth import Auth +from firebolt.common.settings import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.exception import ( + AccountNotFoundError, + FireboltEngineError, + InterfaceError, +) +from firebolt.utils.urls import DYNAMIC_QUERY, GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: from firebolt.db.connection import Connection +ENGINE_STATUS_RUNNING = "Running" + def is_db_available(connection: Connection, database_name: str) -> bool: """ Verify that the database exists. Args: - connection (firebolt.async_db.connection.Connection) + connection (firebolt.db.connection.Connection) + database_name (str): Name of a database """ - resp = _filter_request( - connection, DATABASES_URL, {"filter.name_contains": database_name} - ) - return len(resp.json()["edges"]) > 0 + system_engine = connection._system_engine_connection or connection + with system_engine.cursor() as cursor: + return ( + cursor.execute( + """ + SELECT 1 FROM information_schema.databases WHERE database_name=? + """, + [database_name], + ) + > 0 + ) def is_engine_running(connection: Connection, engine_url: str) -> bool: @@ -28,28 +46,60 @@ def is_engine_running(connection: Connection, engine_url: str) -> bool: Verify that the engine is running. Args: - connection (firebolt.async_db.connection.Connection): connection. + connection (firebolt.db.connection.Connection): connection. + engine_url (str): URL of the engine """ - # Url is not guaranteed to be of this structure, - # but for the sake of error checking this is sufficient. + + if connection._is_system: + # System engine is always running + return True + engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") - resp = _filter_request( - connection, - ENGINES_URL, - { - "filter.name_contains": engine_name, - "filter.current_status_eq": "ENGINE_STATUS_RUNNING_REVISION_SERVING", - }, + assert connection._system_engine_connection is not None # Type check + _, status, _ = _get_engine_url_status_db( + connection._system_engine_connection, engine_name ) - return len(resp.json()["edges"]) > 0 + return status == ENGINE_STATUS_RUNNING -def _filter_request(connection: Connection, endpoint: str, filters: dict) -> Response: - resp = connection._client.request( - # Full url overrides the client url, which contains engine as a prefix. - url=connection.api_endpoint + endpoint, - method="GET", - params=filters, - ) - resp.raise_for_status() - return resp +def _get_system_engine_url( + auth: Auth, + account_name: str, + api_endpoint: str, +) -> str: + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + # return "https://api.us-east-1.dev.firebolt.io/dynamic" + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content}" + ) + return response.json()["engineUrl"] + DYNAMIC_QUERY + + +def _get_engine_url_status_db( + system_engine: Connection, engine_name: str +) -> Tuple[str, str, str]: + with system_engine.cursor() as cursor: + cursor.execute( + """ + SELECT url, attached_to, status FROM information_schema.engines + WHERE engine_name=? + """, + [engine_name], + ) + row = cursor.fetchone() + if row is None: + raise FireboltEngineError(f"Engine with name {engine_name} doesn't exist") + engine_url, database, status = row + return str(engine_url), str(status), str(database) # Mypy check diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 8eab5e8e35b..04c9ef378d9 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -194,8 +194,9 @@ def get_connection(self) -> Connection: """ return connect( database=self.database.name, # type: ignore # already checked by decorator + # we always have firebolt Auth as a client auth auth=self._service.client.auth, # type: ignore - engine_url=self.endpoint, + engine_name=self.name, account_name=self._service.settings.account_name, api_endpoint=self._service.settings.server, ) diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index 70657301dfb..c8d646bfb78 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -3,7 +3,6 @@ from httpx import Timeout from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx -from firebolt.client.auth import Token, UsernamePassword from firebolt.common import Settings from firebolt.service.provider import get_provider_id from firebolt.utils.util import fix_url_schema @@ -28,25 +27,8 @@ class ResourceManager: def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() - - auth = self.settings.auth - - # Deprecated: we shouldn't support passing credentials after 1.0 release - if auth is None: - if self.settings.access_token: - auth = Token(self.settings.access_token) - else: - # mypy checks - assert self.settings.user - assert self.settings.password - auth = UsernamePassword( - self.settings.user, - self.settings.password, - self.settings.use_token_cache, - ) - self.client = Client( - auth=auth, + auth=self.settings.auth, base_url=fix_url_schema(self.settings.server), account_name=self.settings.account_name, api_endpoint=self.settings.server, diff --git a/src/firebolt/utils/exception.py b/src/firebolt/utils/exception.py index b5e39420be1..a5e87212ecb 100644 --- a/src/firebolt/utils/exception.py +++ b/src/firebolt/utils/exception.py @@ -9,6 +9,9 @@ class FireboltEngineError(FireboltError): class EngineNotRunningError(FireboltEngineError): """Engine that's being queried is not running.""" + def __init__(self, engine_name: str): + super().__init__(f"Engine {engine_name} is not running") + class NoAttachedDatabaseError(FireboltEngineError): """Engine that's being accessed is not running. diff --git a/src/firebolt/utils/urls.py b/src/firebolt/utils/urls.py index 0149576e47d..f68f1943946 100644 --- a/src/firebolt/utils/urls.py +++ b/src/firebolt/utils/urls.py @@ -1,5 +1,4 @@ -AUTH_URL = "/auth/v1/login" -AUTH_SERVICE_ACCOUNT_URL = "/auth/v1/token" +AUTH_SERVICE_ACCOUNT_URL = "/oauth/token" DATABASES_URL = "/core/v1/account/databases" @@ -7,7 +6,7 @@ ENGINES_BY_IDS_URL = "/core/v1/engines:getByIds" ACCOUNT_URL = "/iam/v2/account" -ACCOUNT_BY_NAME_URL = "/iam/v2/accounts:getIdByName" +ACCOUNT_BY_NAME_URL = "/web/v3/account/{account_name}/resolve" ACCOUNT_ENGINE_URL = "/core/v1/accounts/{account_id}/engines/{engine_id}" ACCOUNT_ENGINE_START_URL = ACCOUNT_ENGINE_URL + ":start" @@ -29,3 +28,6 @@ PROVIDERS_URL = "/compute/v1/providers" REGIONS_URL = "/compute/v1/regions" + +GATEWAY_HOST_BY_ACCOUNT_NAME = "/web/v3/account/{account_name}/engineUrl" +DYNAMIC_QUERY = "/dynamic/query" diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index 96ac7a1ae64..c7b7533b847 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,11 +1,8 @@ -from functools import lru_cache, partial, wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar +from functools import lru_cache +from typing import TYPE_CHECKING, Callable, Type, TypeVar -import trio from httpx import URL -from firebolt.utils.exception import ConfigurationError - T = TypeVar("T") @@ -72,21 +69,18 @@ def fix_url_schema(url: str) -> str: return url if url.startswith("http") else f"https://{url}" -def async_to_sync(f: Callable) -> Callable: - """Convert async function to sync. +def get_auth_endpoint(api_endpoint: URL) -> URL: + """Create auth endpoint from api endpoint. Args: - f (Callable): function to convert + api_endpoint (URL): provided API endpoint Returns: - Callable: regular function, which can be executed synchronously + URL: authentication endpoint """ - - @wraps(f) - def sync(*args: Any, **kwargs: Any) -> Any: - return trio.run(partial(f, *args, **kwargs)) - - return sync + return api_endpoint.copy_with( + host=".".join(["id"] + api_endpoint.host.split(".")[1:]) + ) def merge_urls(base: URL, merge: URL) -> URL: @@ -105,12 +99,3 @@ def merge_urls(base: URL, merge: URL) -> URL: merge_raw_path = base.raw_path + merge.raw_path.lstrip(b"/") return base.copy_with(raw_path=merge_raw_path) return merge - - -def validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." - ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 810125cf4a9..05a86144f58 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -2,20 +2,15 @@ from os import environ from pytest import fixture -from pytest_asyncio import fixture as async_fixture -from firebolt.client.auth import ServiceAccount, UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.service.manager import Settings LOGGER = getLogger(__name__) -ENGINE_URL_ENV = "ENGINE_URL" ENGINE_NAME_ENV = "ENGINE_NAME" -STOPPED_ENGINE_URL_ENV = "STOPPED_ENGINE_URL" STOPPED_ENGINE_NAME_ENV = "STOPPED_ENGINE_NAME" DATABASE_NAME_ENV = "DATABASE_NAME" -USER_NAME_ENV = "USER_NAME" -PASSWORD_ENV = "PASSWORD" ACCOUNT_NAME_ENV = "ACCOUNT_NAME" API_ENDPOINT_ENV = "API_ENDPOINT" SERVICE_ID_ENV = "SERVICE_ID" @@ -28,26 +23,16 @@ def must_env(var_name: str) -> str: return environ[var_name] -@async_fixture(scope="session") -def rm_settings(api_endpoint, username, password) -> Settings: +@fixture(scope="session") +def rm_settings(api_endpoint, auth, account_name) -> Settings: return Settings( + account_name=account_name, server=api_endpoint, - user=username, - password=password, + auth=auth, default_region="us-east-1", ) -@fixture(scope="session") -def engine_url() -> str: - return must_env(ENGINE_URL_ENV) - - -@fixture(scope="session") -def stopped_engine_url() -> str: - return must_env(STOPPED_ENGINE_URL_ENV) - - @fixture(scope="session") def engine_name() -> str: return must_env(ENGINE_NAME_ENV) @@ -55,7 +40,7 @@ def engine_name() -> str: @fixture(scope="session") def stopped_engine_name() -> str: - return must_env(STOPPED_ENGINE_URL_ENV) + return must_env(STOPPED_ENGINE_NAME_ENV) @fixture(scope="session") @@ -63,16 +48,6 @@ def database_name() -> str: return must_env(DATABASE_NAME_ENV) -@fixture(scope="session") -def username() -> str: - return must_env(USER_NAME_ENV) - - -@fixture(scope="session") -def password() -> str: - return must_env(PASSWORD_ENV) - - @fixture(scope="session") def account_name() -> str: return must_env(ACCOUNT_NAME_ENV) @@ -93,11 +68,6 @@ def service_secret() -> str: return must_env(SERVICE_SECRET_ENV) -@fixture -def service_auth(service_id, service_secret) -> ServiceAccount: - return ServiceAccount(service_id, service_secret) - - @fixture(scope="session") -def password_auth(username, password) -> UsernamePassword: - return UsernamePassword(username, password) +def auth(service_id, service_secret) -> ClientCredentials: + return ClientCredentials(service_id, service_secret) diff --git a/tests/integration/dbapi/async/conftest.py b/tests/integration/dbapi/async/conftest.py index 76bae447748..3bfbbc7ff4d 100644 --- a/tests/integration/dbapi/async/conftest.py +++ b/tests/integration/dbapi/async/conftest.py @@ -1,75 +1,67 @@ -from pytest_asyncio import fixture as async_fixture +from pytest import fixture from firebolt.async_db import Connection, connect from firebolt.client.auth.base import Auth -@async_fixture -async def username_password_connection( - engine_url: str, +@fixture +async def connection( + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: async with await connect( - engine_url=engine_url, + engine_name=engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: yield connection -@async_fixture -async def connection( - engine_url: str, - database_name: str, - password_auth: Auth, +@fixture +async def connection_no_db( + engine_name: str, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: async with await connect( - engine_url=engine_url, - database=database_name, - auth=password_auth, + engine_name=engine_name, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: yield connection -@async_fixture -async def connection_engine_name( - engine_name: str, +@fixture +async def connection_system_engine( database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - async with await connect( - engine_name=engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: yield connection -@async_fixture -async def connection_no_engine( - database_name: str, - password_auth: Auth, +@fixture +async def connection_system_engine_no_db( + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - async with await connect( - database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: diff --git a/tests/integration/dbapi/async/test_errors_async.py b/tests/integration/dbapi/async/test_errors_async.py index be19f63a9b7..127fd327532 100644 --- a/tests/integration/dbapi/async/test_errors_async.py +++ b/tests/integration/dbapi/async/test_errors_async.py @@ -1,13 +1,12 @@ -from httpx import ConnectError -from pytest import mark, raises +from pytest import raises from firebolt.async_db import Connection, connect -from firebolt.client.auth import UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.utils.exception import ( AccountNotFoundError, EngineNotRunningError, - FireboltDatabaseError, FireboltEngineError, + InterfaceError, OperationalError, ) @@ -15,7 +14,7 @@ async def test_invalid_account( database_name: str, engine_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid account error.""" @@ -23,8 +22,8 @@ async def test_invalid_account( with raises(AccountNotFoundError) as exc_info: async with await connect( database=database_name, - engine_name=engine_name, # Omit engine_url to force account_id lookup. - auth=password_auth, + engine_name=engine_name, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -35,29 +34,10 @@ async def test_invalid_account( ), "Invalid account error message." -async def test_engine_url_not_exists( - engine_url: str, - database_name: str, - password_auth: UsernamePassword, - account_name: str, - api_endpoint: str, -) -> None: - """Connection properly reacts to invalid engine url error.""" - async with await connect( - engine_url=engine_url + "_", - database=database_name, - auth=password_auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) as connection: - with raises(ConnectError): - await connection.cursor().execute("show tables") - - async def test_engine_name_not_exists( engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, account_name: str, api_endpoint: str, ) -> None: @@ -66,7 +46,7 @@ async def test_engine_name_not_exists( async with await connect( engine_name=engine_name + "_________", database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -74,45 +54,46 @@ async def test_engine_name_not_exists( async def test_engine_stopped( - stopped_engine_url: str, + stopped_engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine name error.""" with raises(EngineNotRunningError): async with await connect( - engine_url=stopped_engine_url, + engine_name=stopped_engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: await connection.cursor().execute("show tables") -@mark.skip(reason="Behaviour is different in prod vs dev") async def test_database_not_exists( - engine_url: str, + engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, + account_name: str, ) -> None: """Connection properly reacts to invalid database error.""" new_db_name = database_name + "_" - async with await connect( - engine_url=engine_url, - database=new_db_name, - auth=password_auth, - api_endpoint=api_endpoint, - ) as connection: - with raises(FireboltDatabaseError) as exc_info: + with raises(InterfaceError) as exc_info: + async with await connect( + engine_name=engine_name, + database=new_db_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: await connection.cursor().execute("show tables") - assert ( - str(exc_info.value) == f"Database {new_db_name} does not exist" - ), "Invalid database name error message." + assert ( + str(exc_info.value) == f"Database {new_db_name} does not exist" + ), "Invalid database name error message." async def test_sql_error(connection: Connection) -> None: diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index 5b70e9f23bf..4a961cddd2c 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -1,6 +1,6 @@ from datetime import date, datetime from decimal import Decimal -from typing import Any, List +from typing import List from pytest import mark, raises @@ -13,67 +13,13 @@ ) from firebolt.async_db.cursor import QueryStatus from firebolt.common._types import ColType, Column +from tests.integration.dbapi.utils import assert_deep_eq VALS_TO_INSERT_2 = ",".join( [f"({i}, {i-3}, '{val}')" for (i, val) in enumerate(range(4, 1000))] ) LONG_INSERT = f"INSERT INTO test_tbl VALUES {VALS_TO_INSERT_2}" -CREATE_EXTERNAL_TABLE = """CREATE EXTERNAL TABLE IF NOT EXISTS ex_lineitem ( - l_orderkey LONG, - l_partkey LONG, - l_suppkey LONG, - l_linenumber INT, - l_quantity LONG, - l_extendedprice LONG, - l_discount LONG, - l_tax LONG, - l_returnflag TEXT, - l_linestatus TEXT, - l_shipdate TEXT, - l_commitdate TEXT, - l_receiptdate TEXT, - l_shipinstruct TEXT, - l_shipmode TEXT, - l_comment TEXT -) -URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' -OBJECT_PATTERN = '*.parquet' -TYPE = (PARQUET);""" - -CREATE_FACT_TABLE = """CREATE FACT TABLE IF NOT EXISTS lineitem ( --- In this example, these fact table columns --- map directly to the external table columns. - l_orderkey LONG, - l_partkey LONG, - l_suppkey LONG, - l_linenumber INT, - l_quantity LONG, - l_extendedprice LONG, - l_discount LONG, - l_tax LONG, - l_returnflag TEXT, - l_linestatus TEXT, - l_shipdate TEXT, - l_commitdate TEXT, - l_receiptdate TEXT, - l_shipinstruct TEXT, - l_shipmode TEXT, - l_comment TEXT -) -PRIMARY INDEX - l_orderkey, - l_linenumber; -""" - - -def assert_deep_eq(got: Any, expected: Any, msg: str) -> bool: - if type(got) == list and type(expected) == list: - all([assert_deep_eq(f, s, msg) for f, s in zip(got, expected)]) - assert ( - type(got) == type(expected) and got == expected - ), f"{msg}: {got}(got) != {expected}(expected)" - async def status_loop( query_id: str, @@ -98,25 +44,8 @@ async def status_loop( ), f"Failed {query}. Got {status} rather than {final_status}." -async def test_connect_engine_name( - connection_engine_name: Connection, - all_types_query: str, - all_types_query_description: List[Column], - all_types_query_response: List[ColType], - timezone_name: str, -) -> None: - """Connecting with engine name is handled properly.""" - await test_select( - connection_engine_name, - all_types_query, - all_types_query_description, - all_types_query_response, - timezone_name, - ) - - -async def test_connect_no_engine( - connection_no_engine: Connection, +async def test_connect_no_db( + connection_no_db: Connection, all_types_query: str, all_types_query_description: List[Column], all_types_query_response: List[ColType], @@ -124,7 +53,7 @@ async def test_connect_no_engine( ) -> None: """Connecting with engine name is handled properly.""" await test_select( - connection_no_engine, + connection_no_db, all_types_query, all_types_query_description, all_types_query_response, @@ -198,21 +127,18 @@ async def test_long_query( async def test_drop_create(connection: Connection) -> None: """Create and drop table/index queries are handled properly.""" - async def test_query(c: Cursor, query: str) -> None: + async def test_query(c: Cursor, query: str, empty_response=True) -> None: await c.execute(query) assert c.description == None - # This is inconsistent, commenting for now - # assert c.rowcount == -1 + assert c.rowcount == (-1 if empty_response else 0) """Create table query is handled properly""" with connection.cursor() as c: # Cleanup - await c.execute("DROP JOIN INDEX IF EXISTS test_drop_create_async_db_join_idx") - await c.execute( - "DROP AGGREGATING INDEX IF EXISTS test_drop_create_async_db_agg_idx" - ) - await c.execute("DROP TABLE IF EXISTS test_drop_create_async_tb") - await c.execute("DROP TABLE IF EXISTS test_drop_create_async_tb_dim") + await c.execute("DROP JOIN INDEX IF EXISTS test_db_join_idx") + await c.execute("DROP AGGREGATING INDEX IF EXISTS test_db_agg_idx") + await c.execute("DROP TABLE IF EXISTS test_drop_create_async") + await c.execute("DROP TABLE IF EXISTS test_drop_create_async_dim") # Fact table await test_query( @@ -240,6 +166,7 @@ async def test_query(c: Cursor, query: str) -> None: c, "CREATE AGGREGATING INDEX test_db_agg_idx ON " "test_drop_create_async(id, sum(f), count(dt))", + empty_response=False, ) # Drop join index diff --git a/tests/integration/dbapi/async/test_system_engine.py b/tests/integration/dbapi/async/test_system_engine.py new file mode 100644 index 00000000000..f04f27d04c3 --- /dev/null +++ b/tests/integration/dbapi/async/test_system_engine.py @@ -0,0 +1,213 @@ +from typing import List + +from pytest import fixture, mark, raises + +from firebolt.async_db import Connection +from firebolt.common._types import ColType, Column +from firebolt.utils.exception import OperationalError +from tests.integration.dbapi.utils import assert_deep_eq + + +@fixture +def db_name(database_name): + return database_name + "_system_test" + + +@fixture +def second_db_name(database_name): + return database_name + "_system_test_two" + + +@fixture +def region(): + return "us-east-1" + + +@fixture +def engine_name(engine_name): + return engine_name + "_system_test" + + +@fixture +async def setup_dbs( + connection_system_engine, db_name, second_db_name, engine_name, region +): + with connection_system_engine.cursor() as cursor: + + await cursor.execute(f"DROP DATABASE IF EXISTS {db_name}") + await cursor.execute(f"DROP DATABASE IF EXISTS {second_db_name}") + + await cursor.execute(create_database(name=db_name)) + + await cursor.execute(create_engine(engine_name, engine_specs(region))) + + await cursor.execute( + create_database(name=second_db_name, specs=db_specs(region, engine_name)) + ) + + yield + + await cursor.execute(f"DROP ENGINE IF EXISTS {engine_name}") + await cursor.execute(f"DROP DATABASE IF EXISTS {db_name}") + await cursor.execute(f"DROP DATABASE IF EXISTS {second_db_name}") + + +async def test_system_engine( + connection_system_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + with connection_system_engine.cursor() as c: + assert await c.execute(all_types_query) == 1, "Invalid row count returned" + assert c.rowcount == 1, "Invalid rowcount value" + data = await c.fetchall() + assert len(data) == c.rowcount, "Invalid data length" + assert_deep_eq(data, all_types_query_system_engine_response, "Invalid data") + assert c.description == all_types_query_description, "Invalid description value" + assert len(data[0]) == len(c.description), "Invalid description length" + assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" + + # Different fetch types + await c.execute(all_types_query) + assert ( + await c.fetchone() == all_types_query_system_engine_response[0] + ), "Invalid fetchone data" + assert await c.fetchone() is None, "Redundant data returned by fetchone" + + await c.execute(all_types_query) + assert len(await c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany" + data = await c.fetchmany() + assert len(data) == 1, "Invalid data size returned by fetchmany" + assert_deep_eq( + data, + all_types_query_system_engine_response, + "Invalid data returned by fetchmany", + ) + + if connection_system_engine.database: + await c.execute("show tables") + with raises(OperationalError): + await c.execute("create table test(id int) primary index id") + else: + await c.execute("show databases") + with raises(OperationalError): + await c.execute("show tables") + + +async def test_system_engine_no_db( + connection_system_engine_no_db: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + await test_system_engine( + connection_system_engine_no_db, + all_types_query, + all_types_query_description, + all_types_query_system_engine_response, + timezone_name, + ) + + +def engine_specs(region): + return f"REGION = '{region}' " "SPEC = 'B1' " "SCALE = 1" + + +def create_database(name, specs=None): + query = f"CREATE DATABASE {name}" + query += f" WITH {specs}" if specs else "" + return query + + +def create_engine(name, specs=None): + query = f"CREATE ENGINE {name}" + query += f" WITH {specs}" if specs else "" + return query + + +def db_specs(region, attached_engine): + return ( + f"REGION = '{region}' " + f"ATTACHED_ENGINES = ('{attached_engine}') " + "DESCRIPTION = 'Sample description'" + ) + + +@mark.parametrize( + "query", + ["CREATE DIMENSION TABLE dummy(id INT)"], +) +async def test_query_errors(connection_system_engine, query): + with connection_system_engine.cursor() as cursor: + with raises(OperationalError): + await cursor.execute(query) + + +@mark.xdist_group(name="system_engine") +async def test_show_databases(setup_dbs, connection_system_engine, db_name): + with connection_system_engine.cursor() as cursor: + + await cursor.execute("SHOW DATABASES") + + dbs = [row[0] for row in await cursor.fetchall()] + + assert db_name in dbs + assert f"{db_name}_two" in dbs + + +@mark.xdist_group(name="system_engine") +async def test_detach_engine( + setup_dbs, connection_system_engine, engine_name, second_db_name +): + async def check_engine_exists(cursor, engine_name, db_name): + await cursor.execute("SHOW ENGINES") + engines = await cursor.fetchall() + # Results have the following columns + # engine_name, region, spec, scale, status, attached_to, version + assert engine_name in [row[0] for row in engines] + assert (engine_name, db_name) in [(row[0], row[5]) for row in engines] + + with connection_system_engine.cursor() as cursor: + await check_engine_exists(cursor, engine_name, db_name=second_db_name) + await cursor.execute(f"DETACH ENGINE {engine_name} FROM {second_db_name}") + + # When engine not attached db is - + await check_engine_exists(cursor, engine_name, db_name="-") + + await cursor.execute(f"ATTACH ENGINE {engine_name} TO {second_db_name}") + await check_engine_exists(cursor, engine_name, db_name=second_db_name) + + +@mark.xdist_group(name="system_engine") +async def test_alter_engine(setup_dbs, connection_system_engine, engine_name): + with connection_system_engine.cursor() as cursor: + await cursor.execute(f"ALTER ENGINE {engine_name} SET AUTO_STOP = 60") + + await cursor.execute( + "SELECT engine_name, auto_stop FROM information_schema.engines" + ) + engines = await cursor.fetchall() + assert [engine_name, 3600] in engines + + +@mark.xdist_group(name="system_engine") +async def test_start_stop_engine(setup_dbs, connection_system_engine, engine_name): + async def check_engine_status(cursor, engine_name, status): + await cursor.execute("SHOW ENGINES") + engines = await cursor.fetchall() + # Results have the following columns + # engine_name, region, spec, scale, status, attached_to, version + assert engine_name in [row[0] for row in engines] + assert (engine_name, status) in [(row[0], row[4]) for row in engines] + + with connection_system_engine.cursor() as cursor: + await check_engine_status(cursor, engine_name, "Stopped") + await cursor.execute(f"START ENGINE {engine_name}") + await check_engine_status(cursor, engine_name, "Running") + await cursor.execute(f"STOP ENGINE {engine_name}") + await check_engine_status(cursor, engine_name, "Stopped") diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index b61bedf4077..7495af88b94 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -66,16 +66,14 @@ def all_types_query() -> str: "1.2345678901234 as float64, " "'text' as \"string\", " "CAST('2021-03-28' AS DATE) as \"date\", " - "CAST('1860-03-04' AS DATE_EXT) as \"date32\"," "pgdate '0001-01-01' as \"pgdate\", " "CAST('2019-07-31 01:01:01' AS DATETIME) as \"datetime\", " - "CAST('2019-07-31 01:01:01.1234' AS TIMESTAMP_EXT(4)) as \"datetime64\", " "CAST('1111-01-05 17:04:42.123456' as timestampntz) as \"timestampntz\", " "'1111-01-05 17:04:42.123456'::timestamptz as \"timestamptz\", " 'true as "boolean", ' "[1,2,3,4] as \"array\", cast('1231232.123459999990457054844258706536' as " 'decimal(38,30)) as "decimal", ' - 'cast(null as int) as "nullable", ' + 'null as "nullable", ' "'abc123'::bytea as \"bytea\"" ) @@ -95,16 +93,14 @@ def all_types_query_description() -> List[Column]: Column("float64", float, None, None, None, None, None), Column("string", str, None, None, None, None, None), Column("date", date, None, None, None, None, None), - Column("date32", date, None, None, None, None, None), Column("pgdate", date, None, None, None, None, None), Column("datetime", datetime, None, None, None, None, None), - Column("datetime64", datetime, None, None, None, None, None), Column("timestampntz", datetime, None, None, None, None, None), Column("timestamptz", datetime, None, None, None, None, None), Column("boolean", bool, None, None, None, None, None), Column("array", ARRAY(int), None, None, None, None, None), Column("decimal", DECIMAL(38, 30), None, None, None, None, None), - Column("nullable", int, None, None, None, None, None), + Column("nullable", str, None, None, None, None, None), Column("bytea", bytes, None, None, None, None, None), ] @@ -122,13 +118,11 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]: 30000000000, -30000000000, 1.23, - 1.23456789012, + 1.2345678901234, "text", date(2021, 3, 28), - date(1860, 3, 4), date(1, 1, 1), datetime(2019, 7, 31, 1, 1, 1), - datetime(2019, 7, 31, 1, 1, 1, 123400), datetime(1111, 1, 5, 17, 4, 42, 123456), datetime( 1111, @@ -149,6 +143,37 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]: ] +@fixture +def all_types_query_system_engine_response( + timezone_offset_seconds: int, +) -> List[ColType]: + return [ + [ + 1, + -1, + 257, + -257, + 80000, + -80000, + 30000000000, + -30000000000, + 1.23, + 1.23456789012, + "text", + date(2021, 3, 28), + date(1, 1, 1), + datetime(2019, 7, 31, 1, 1, 1), + datetime(1111, 1, 5, 17, 4, 42, 123456), + datetime(1111, 1, 5, 17, 4, 42, 123456, tzinfo=timezone.utc), + True, + [1, 2, 3, 4], + Decimal("1231232.123459999990457054844258706536"), + None, + b"abc123", + ] + ] + + @fixture def timezone_name() -> str: return "Asia/Calcutta" diff --git a/tests/integration/dbapi/sync/conftest.py b/tests/integration/dbapi/sync/conftest.py index e8c4447ea08..2b561c7fb29 100644 --- a/tests/integration/dbapi/sync/conftest.py +++ b/tests/integration/dbapi/sync/conftest.py @@ -5,91 +5,64 @@ @fixture -def username_password_connection( - engine_url: str, - database_name: str, - password_auth: Auth, - account_name: str, - api_endpoint: str, -) -> Connection: - connection = connect( - engine_url=engine_url, - database=database_name, - auth=password_auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - yield connection - connection.close() - - -@fixture -async def connection( - engine_url: str, +def connection( + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( - engine_url=engine_url, + with connect( + engine_name=engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection @fixture -def connection_engine_name( +def connection_no_db( engine_name: str, - database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( + with connect( engine_name=engine_name, - database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection @fixture -def connection_no_engine( +def connection_system_engine( database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( + with connect( database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection -@fixture(scope="session") -def connection_system_engine( - password_auth: Auth, +@fixture +def connection_system_engine_no_db( + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( - database="dummy", - engine_name="system", - auth=password_auth, + with connect( + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection diff --git a/tests/integration/dbapi/sync/test_errors.py b/tests/integration/dbapi/sync/test_errors.py index dd5017c0202..088041d9833 100644 --- a/tests/integration/dbapi/sync/test_errors.py +++ b/tests/integration/dbapi/sync/test_errors.py @@ -1,7 +1,6 @@ -from httpx import ConnectError from pytest import mark, raises -from firebolt.client.auth import UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.db import Connection, connect from firebolt.utils.exception import ( AccountNotFoundError, @@ -15,7 +14,7 @@ def test_invalid_account( database_name: str, engine_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid account error.""" @@ -24,7 +23,7 @@ def test_invalid_account( with connect( database=database_name, engine_name=engine_name, # Omit engine_url to force account_id lookup. - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -35,52 +34,39 @@ def test_invalid_account( ), "Invalid account error message." -def test_engine_url_not_exists( - engine_url: str, - database_name: str, - password_auth: UsernamePassword, - api_endpoint: str, -) -> None: - """Connection properly reacts to invalid engine url error.""" - with connect( - engine_url=engine_url + "_", - database=database_name, - auth=password_auth, - api_endpoint=api_endpoint, - ) as connection: - with raises(ConnectError): - connection.cursor().execute("show tables") - - def test_engine_name_not_exists( engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, + account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine name error.""" with raises(FireboltEngineError): with connect( + account_name=account_name, engine_name=engine_name + "_________", database=database_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: connection.cursor().execute("show tables") def test_engine_stopped( - stopped_engine_url: str, + stopped_engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, + account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to engine not running error.""" with raises(EngineNotRunningError): with connect( - engine_url=stopped_engine_url, + account_name=account_name, + engine_name=stopped_engine_name, database=database_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: connection.cursor().execute("show tables") @@ -90,15 +76,17 @@ def test_engine_stopped( def test_database_not_exists( engine_url: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, + account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to invalid database error.""" new_db_name = database_name + "_" with connect( + account_name=account_name, engine_url=engine_url, database=new_db_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: with raises(FireboltDatabaseError) as exc_info: diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 03e0ff22007..32c7320fb16 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -16,6 +16,7 @@ OperationalError, connect, ) +from tests.integration.dbapi.utils import assert_deep_eq VALS_TO_INSERT = ",".join([f"({i},'{val}')" for (i, val) in enumerate(range(1, 360))]) LONG_INSERT = f"INSERT INTO test_tbl VALUES {VALS_TO_INSERT}" @@ -52,8 +53,8 @@ def status_loop( ), f"Failed {query}. Got {status} rather than {final_status}." -def test_connect_engine_name( - connection_engine_name: Connection, +def test_connect_no_db( + connection_no_db: Connection, all_types_query: str, all_types_query_description: List[Column], all_types_query_response: List[ColType], @@ -61,24 +62,7 @@ def test_connect_engine_name( ) -> None: """Connecting with engine name is handled properly.""" test_select( - connection_engine_name, - all_types_query, - all_types_query_description, - all_types_query_response, - timezone_name, - ) - - -def test_connect_no_engine( - connection_no_engine: Connection, - all_types_query: str, - all_types_query_description: List[Column], - all_types_query_response: List[ColType], - timezone_name: str, -) -> None: - """Connecting with engine name is handled properly.""" - test_select( - connection_no_engine, + connection_no_db, all_types_query, all_types_query_description, all_types_query_response, @@ -149,11 +133,10 @@ def test_long_query( def test_drop_create(connection: Connection) -> None: """Create and drop table/index queries are handled properly.""" - def test_query(c: Cursor, query: str) -> None: + def test_query(c: Cursor, query: str, empty_response=True) -> None: c.execute(query) assert c.description == None - # Inconsistent behaviour in Firebolt - # assert c.rowcount == -1 + assert c.rowcount == (-1 if empty_response else 0) """Create table query is handled properly""" with connection.cursor() as c: @@ -189,6 +172,7 @@ def test_query(c: Cursor, query: str) -> None: c, "CREATE AGGREGATING INDEX test_drop_create_db_agg_idx ON " "test_drop_create_tb(id, sum(f), count(dt))", + empty_response=False, ) # Drop join index @@ -379,9 +363,9 @@ def test_set_invalid_parameter(connection: Connection): # Run test multiple times since the issue is flaky @mark.parametrize("_", range(5)) def test_anyio_backend_import_issue( - engine_url: str, + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, _: int, @@ -392,13 +376,13 @@ def test_anyio_backend_import_issue( exceptions = [] def run_query(idx: int): - nonlocal password_auth, database_name, engine_url, account_name, api_endpoint + nonlocal auth, database_name, engine_name, account_name, api_endpoint try: with connect( - auth=password_auth, + auth=auth, database=database_name, account_name=account_name, - engine_url=engine_url, + engine_name=engine_name, api_endpoint=api_endpoint, ) as c: cursor = c.cursor() @@ -469,9 +453,9 @@ async def test_server_side_async_execution_get_status( def test_multi_thread_connection_sharing( - engine_url: str, + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> None: @@ -485,10 +469,10 @@ def test_multi_thread_connection_sharing( exceptions = [] connection = connect( - auth=password_auth, + auth=auth, database=database_name, account_name=account_name, - engine_url=engine_url, + engine_name=engine_name, api_endpoint=api_endpoint, ) diff --git a/tests/integration/dbapi/sync/test_system_engine.py b/tests/integration/dbapi/sync/test_system_engine.py index 2474c323dfc..58b4ec10418 100644 --- a/tests/integration/dbapi/sync/test_system_engine.py +++ b/tests/integration/dbapi/sync/test_system_engine.py @@ -1,29 +1,34 @@ +from typing import List + from pytest import fixture, mark, raises +from firebolt.common._types import ColType, Column +from firebolt.db import Connection from firebolt.utils.exception import OperationalError +from tests.integration.dbapi.utils import assert_deep_eq -@fixture(scope="module") +@fixture def db_name(database_name): return database_name + "_system_test" -@fixture(scope="module") +@fixture def second_db_name(database_name): return database_name + "_system_test_two" -@fixture(scope="module") +@fixture def region(): return "us-east-1" -@fixture(scope="module") +@fixture def engine_name(engine_name): return engine_name + "_system_test" -@fixture(scope="module") +@fixture def setup_dbs(connection_system_engine, db_name, second_db_name, engine_name, region): with connection_system_engine.cursor() as cursor: @@ -42,6 +47,68 @@ def setup_dbs(connection_system_engine, db_name, second_db_name, engine_name, re cursor.execute(f"DROP DATABASE IF EXISTS {second_db_name}") +def test_system_engine( + connection_system_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + with connection_system_engine.cursor() as c: + assert c.execute(all_types_query) == 1, "Invalid row count returned" + assert c.rowcount == 1, "Invalid rowcount value" + data = c.fetchall() + assert len(data) == c.rowcount, "Invalid data length" + assert_deep_eq(data, all_types_query_system_engine_response, "Invalid data") + assert c.description == all_types_query_description, "Invalid description value" + assert len(data[0]) == len(c.description), "Invalid description length" + assert len(c.fetchall()) == 0, "Redundant data returned by fetchall" + + # Different fetch types + c.execute(all_types_query) + assert ( + c.fetchone() == all_types_query_system_engine_response[0] + ), "Invalid fetchone data" + assert c.fetchone() is None, "Redundant data returned by fetchone" + + c.execute(all_types_query) + assert len(c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany" + data = c.fetchmany() + assert len(data) == 1, "Invalid data size returned by fetchmany" + assert_deep_eq( + data, + all_types_query_system_engine_response, + "Invalid data returned by fetchmany", + ) + + if connection_system_engine.database: + c.execute("show tables") + with raises(OperationalError): + c.execute("create table test(id int) primary index id") + else: + c.execute("show databases") + with raises(OperationalError): + c.execute("show tables") + + +def test_system_engine_no_db( + connection_system_engine_no_db: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + test_system_engine( + connection_system_engine_no_db, + all_types_query, + all_types_query_description, + all_types_query_system_engine_response, + timezone_name, + ) + + def engine_specs(region): return f"REGION = '{region}' " "SPEC = 'B1' " "SCALE = 1" @@ -114,11 +181,11 @@ def check_engine_exists(cursor, engine_name, db_name): @mark.xdist_group(name="system_engine") def test_alter_engine(setup_dbs, connection_system_engine, engine_name): with connection_system_engine.cursor() as cursor: - cursor.execute(f"ALTER ENGINE {engine_name} SET SPEC = B2") + cursor.execute(f"ALTER ENGINE {engine_name} SET AUTO_STOP = 60") - cursor.execute("SHOW ENGINES") + cursor.execute("SELECT engine_name, auto_stop FROM information_schema.engines") engines = cursor.fetchall() - assert (engine_name, "B2") in [(row[0], row[2]) for row in engines] + assert [engine_name, 3600] in engines @mark.xdist_group(name="system_engine") @@ -137,10 +204,3 @@ def check_engine_status(cursor, engine_name, status): check_engine_status(cursor, engine_name, "Running") cursor.execute(f"STOP ENGINE {engine_name}") check_engine_status(cursor, engine_name, "Stopped") - - -@mark.xdist_group(name="system_engine") -def test_select_one(connection_system_engine): - """SELECT statements are supported""" - with connection_system_engine.cursor() as cursor: - cursor.execute("SELECT 1") diff --git a/tests/integration/dbapi/utils.py b/tests/integration/dbapi/utils.py new file mode 100644 index 00000000000..603c58a7d7b --- /dev/null +++ b/tests/integration/dbapi/utils.py @@ -0,0 +1,9 @@ +from typing import Any + + +def assert_deep_eq(got: Any, expected: Any, msg: str) -> bool: + if type(got) == list and type(expected) == list: + all([assert_deep_eq(f, s, msg) for f, s in zip(got, expected)]) + assert ( + type(got) == type(expected) and got == expected + ), f"{msg}: {got}(got) != {expected}(expected)" diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index fdca43ab438..8be96d4ccf5 100644 --- a/tests/unit/async_db/conftest.py +++ b/tests/unit/async_db/conftest.py @@ -2,29 +2,36 @@ from typing import Dict from pytest import fixture -from pytest_asyncio import fixture as asyncio_fixture from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect +from firebolt.client.auth import Auth from firebolt.common.settings import Settings from tests.unit.db_conftest import * # noqa -@asyncio_fixture -async def connection(settings: Settings, db_name: str) -> Connection: +@fixture +async def connection( + server: str, + db_name: str, + auth: Auth, + engine_name: str, + account_name: str, + mock_connection_flow: Callable, +) -> Connection: + mock_connection_flow() async with ( await connect( - engine_url=settings.server, + engine_name=engine_name, database=db_name, - username="u", - password="p", - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) ) as connection: yield connection -@asyncio_fixture +@fixture async def cursor(connection: Connection, settings: Settings) -> Cursor: return connection.cursor() diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 6da6cf46445..ac63f1f4aa3 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,24 +1,21 @@ -from re import Pattern from typing import Callable, List from unittest.mock import patch -from httpx import codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises from pytest_httpx import HTTPXMock from firebolt.async_db.connection import Connection, connect -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.utils.exception import ( - AccountNotFoundError, ConfigurationError, ConnectionClosedError, - FireboltEngineError, + EngineNotRunningError, + InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import ACCOUNT_ENGINE_ID_BY_NAME_URL async def test_closed_connection(connection: Connection) -> None: @@ -52,194 +49,125 @@ async def test_cursors_closed_on_close(connection: Connection) -> None: async def test_cursor_initialized( settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_query: Callable, + connection: Connection, python_query_data: List[List[ColType]], ) -> None: """Connection initialized its cursors properly.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() - for url in (settings.server, f"https://{settings.server}"): - async with ( - await connect( - engine_url=url, - database=db_name, - username="u", - password="p", - api_endpoint=settings.server, - ) - ) as connection: - cursor = connection.cursor() - assert ( - cursor.connection == connection - ), "Invalid cursor connection attribute." - assert ( - cursor._client == connection._client - ), "Invalid cursor _client attribute" + cursor = connection.cursor() + assert cursor.connection == connection, "Invalid cursor connection attribute." + assert cursor._client == connection._client, "Invalid cursor _client attribute" - assert await cursor.execute("select*") == len(python_query_data) + assert await cursor.execute("select*") == len(python_query_data) - cursor.close() - assert ( - cursor not in connection._cursors - ), "Cursor wasn't removed from connection after close." + cursor.close() + assert ( + cursor not in connection._cursors + ), "Cursor wasn't removed from connection after close." async def test_connect_empty_parameters(): with raises(ConfigurationError): - async with await connect(engine_url="engine_url"): - pass - - -async def test_connect_access_token( - settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - check_token_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], - access_token: str, -): - httpx_mock.add_callback(check_token_callback, url=query_url) - async with ( - await connect( - engine_url=settings.server, - database=db_name, - access_token=access_token, - api_endpoint=settings.server, - ) - ) as connection: - cursor = connection.cursor() - assert await cursor.execute("select*") == -1 - - with raises(ConfigurationError): - async with await connect(engine_url="engine_url", database="database"): - pass - - with raises(ConfigurationError): - async with await connect( - engine_url="engine_url", - database="database", - username="username", - password="password", - access_token="access_token", - ): + async with await connect(engine_name="engine_name"): pass async def test_connect_engine_name( - settings: Settings, db_name: str, + account_name: str, + engine_name: str, + auth: Auth, + server: str, + python_query_data: List[List[ColType]], httpx_mock: HTTPXMock, - auth_callback: Callable, + mock_query: Callable, + system_engine_query_url: str, + get_engine_url_not_running_callback: Callable, + get_engine_url_invalid_db_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + get_engine_url_callback: Callable, + account_id_url: str, account_id_callback: Callable, - engine_id: str, - get_engine_url_by_id_url: str, - get_engine_url_by_id_callback: Callable, - python_query_data: List[List[ColType]], - account_id: str, ): """connect properly handles engine_name""" - with raises(ConfigurationError): - async with await connect( - engine_url="engine_url", - engine_name="engine_name", - database="db", - username="username", - password="password", - account_name="account_name", - ): - pass - - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(get_engine_url_by_id_callback, url=get_engine_url_by_id_url) - engine_name = settings.server.split(".")[0] + mock_query() - # Mock engine id lookup error - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.NOT_FOUND, - ) - - with raises(FireboltEngineError): - async with await connect( - database="db", - username="username", - password="password", - engine_name=engine_name, - account_name=settings.account_name, - api_endpoint=settings.server, - ): - pass + for callback, err_cls in ( + (get_engine_url_invalid_db_callback, InterfaceError), + (get_engine_url_not_running_callback, EngineNotRunningError), + ): + httpx_mock.add_callback(callback, url=system_engine_query_url) + with raises(err_cls): + async with await connect( + database=db_name, + auth=auth, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + ): + pass - # Mock engine id lookup by name - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.OK, - json={"engine_id": {"engine_id": engine_id}}, - ) + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) async with await connect( engine_name=engine_name, database=db_name, - username="u", - password="p", - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len(python_query_data) -async def test_connect_default_engine( - settings: Settings, +async def test_connect_database( db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - engine_by_db_url: str, + server: str, + auth: Auth, + account_name: str, python_query_data: List[List[ColType]], + httpx_mock: HTTPXMock, + query_callback: str, + check_credentials_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + get_system_engine_url: str, + get_system_engine_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(query_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + async with await connect( + database=None, + auth=auth, + account_name=account_name, + api_endpoint=server, + ) as connection: + await connection.cursor().execute("select*") + + httpx_mock.reset(True) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(query_callback, url=system_engine_query_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) - engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}" - - httpx_mock.add_response( - url=engine_by_db_url, - status_code=codes.OK, - json={ - "engine_url": settings.server, - }, - ) + async with await connect( database=db_name, - username="u", - password="p", - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len(python_query_data) @@ -255,129 +183,67 @@ async def test_connection_commit(connection: Connection): @mark.nofakefs async def test_connection_token_caching( - settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], + server: str, access_token: str, + client_id: str, + client_secret: str, + engine_name: str, + account_name: str, + python_query_data: List[List[ColType]], + mock_connection_flow: Callable, + mock_query: Callable, ) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_connection_flow() + mock_query() with Patcher(): async with await connect( database=db_name, - username=settings.user, - password=settings.password, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - use_token_cache=True, + auth=ClientCredentials(client_id, client_secret, use_token_cache=True), + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): async with await connect( database=db_name, - username=settings.user, - password=settings.password, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - use_token_cache=False, + auth=ClientCredentials(client_id, client_secret, use_token_cache=False), + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" -async def test_connect_with_auth( - httpx_mock: HTTPXMock, - settings: Settings, +async def test_connect_with_user_agent( + engine_name: str, + account_name: str, + server: str, db_name: str, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - access_token: str, -) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - - for auth in ( - UsernamePassword( - settings.user, - settings.password, - use_token_cache=False, - ), - Token(access_token), - ): - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ) as connection: - await connection.cursor().execute("select*") - - -async def test_connect_account_name( - httpx_mock: HTTPXMock, auth: Auth, - settings: Settings, - db_name: str, - auth_url: str, - check_credentials_callback: Callable, - account_id_url: Pattern, - account_id_callback: Callable, -): - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - - with raises(AccountNotFoundError): - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name="invalid", - api_endpoint=settings.server, - ): - pass - - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ): - pass - - -async def test_connect_with_user_agent( + access_token: str, httpx_mock: HTTPXMock, - settings: Settings, - db_name: str, query_callback: Callable, query_url: str, - access_token: str, + mock_connection_flow: Callable, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" + mock_connection_flow() httpx_mock.add_callback( query_callback, url=query_url, @@ -385,40 +251,45 @@ async def test_connect_with_user_agent( ) async with await connect( - auth=Token(access_token), + auth=auth, database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, additional_parameters={ "user_clients": [("MyConnector", "1.0")], "user_drivers": [("DriverA", "1.1")], }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) async def test_connect_no_user_agent( - httpx_mock: HTTPXMock, - settings: Settings, + engine_name: str, + account_name: str, + server: str, db_name: str, + auth: Auth, + access_token: str, + httpx_mock: HTTPXMock, query_callback: Callable, query_url: str, - access_token: str, + mock_connection_flow: Callable, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" + mock_connection_flow() httpx_mock.add_callback( query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} ) async with await connect( - auth=Token(access_token), + auth=auth, database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([], []) + ut.assert_called_with([], []) diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 64f13d4318f..c9e02e69a64 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -13,7 +13,6 @@ CursorClosedError, DataError, EngineNotRunningError, - FireboltDatabaseError, OperationalError, QueryNotRunError, ) @@ -22,15 +21,12 @@ async def test_cursor_state( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, + mock_query: Callable, query_url: str, cursor: Cursor, ): """Cursor state changes depend on the operations performed with it.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert cursor._state == CursorState.NONE @@ -95,11 +91,7 @@ async def test_closed_cursor(cursor: Cursor): async def test_cursor_no_query( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - server_side_async_id: str, + mock_query: Callable, cursor: Cursor, ): """Some cursor methods are unavailable until a query is run.""" @@ -109,8 +101,7 @@ async def test_cursor_no_query( "fetchall", ) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() for amethod in async_methods: with raises(QueryNotRunError): @@ -145,12 +136,8 @@ async def test_cursor_no_query( async def test_cursor_execute( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], @@ -167,8 +154,7 @@ async def test_cursor_execute( ), ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert await query() == len( python_query_data ), f"Invalid row count returned for {message}." @@ -189,11 +175,8 @@ async def test_cursor_execute( await cursor.fetchone() is None ), f"Non-empty fetchone after all data received {message}." - httpx_mock.reset(True) - # Query with empty output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() assert await query() == -1, f"Invalid row count for insert using {message}." assert ( cursor.rowcount == -1 @@ -205,12 +188,10 @@ async def test_cursor_execute( async def test_cursor_execute_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_url: str, - get_engines_url: str, - get_databases_url: str, cursor: Cursor, + get_engine_url_not_running_callback: Callable, + system_engine_query_url: str, ): """Cursor handles all types of errors properly.""" for query, message in ( @@ -223,8 +204,6 @@ async def test_cursor_execute_error( "server-side synchronous executemany()", ), ): - httpx_mock.add_callback(auth_callback, url=auth_url) - # Internal httpx error def http_error(*args, **kwargs): raise StreamError("httpx error") @@ -261,32 +240,15 @@ def http_error(*args, **kwargs): str(excinfo.value) == "Error executing query:\nQuery error message" ), f"Invalid authentication error message for {message}." - # Database does not exist error - httpx_mock.add_response( - status_code=codes.FORBIDDEN, - content="Query error message", - url=query_url, - ) - httpx_mock.add_response( - json={"edges": []}, - url=get_databases_url + "?filter.name_contains=database", - ) - with raises(FireboltDatabaseError) as excinfo: - await query() - assert cursor._state == CursorState.ERROR - # Engine is not running error httpx_mock.add_response( status_code=codes.SERVICE_UNAVAILABLE, content="Query error message", url=query_url, ) - httpx_mock.add_response( - json={"edges": []}, - url=( - get_engines_url + "?filter.name_contains=api_dev" - "&filter.current_status_eq=ENGINE_STATUS_RUNNING_REVISION_SERVING" - ), + httpx_mock.add_callback( + get_engine_url_not_running_callback, + url=system_engine_query_url, ) with raises(EngineNotRunningError) as excinfo: await query() @@ -297,8 +259,6 @@ def http_error(*args, **kwargs): async def test_cursor_server_side_async_execute_errors( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_with_params_url: str, server_side_async_missing_id_callback: Callable, insert_query_callback: str, @@ -318,8 +278,6 @@ async def test_cursor_server_side_async_execute_errors( "server-side asynchronous executemany()", ), ): - # Empty server-side asynchronous execution return. - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(insert_query_callback, url=query_with_params_url) with raises(OperationalError) as excinfo: await query("SELECT * FROM t") @@ -328,7 +286,6 @@ async def test_cursor_server_side_async_execute_errors( assert str(excinfo.value) == ("No response to asynchronous query.") # Missing query_id from server-side asynchronous execution. - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_missing_id_callback, url=query_with_params_url ) @@ -341,7 +298,6 @@ async def test_cursor_server_side_async_execute_errors( ) # Multi-statement queries are not possible with async_execution error. - httpx_mock.add_callback(auth_callback, url=auth_url) with raises(AsyncExecutionUnavailableError) as excinfo: await query("SELECT * FROM t; SELECT * FROM s") @@ -374,23 +330,17 @@ async def test_cursor_server_side_async_execute_errors( ), f"use_standard_sql=0 was allowed for server-side asynchronous queries on {message}." # Have to reauth or next execute fails. Not sure why. - httpx_mock.add_callback(auth_callback, url=auth_url) await cursor.execute("set use_standard_sql=1") httpx_mock.reset(True) async def test_cursor_fetchone( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchone fetches single row in correct order. If no rows, returns None.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() await cursor.execute("sql") @@ -405,27 +355,22 @@ async def test_cursor_fetchone( await cursor.fetchone() is None ), "fetchone should return None when no rows left to fetch." - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() await cursor.execute("sql") with raises(DataError): await cursor.fetchone() async def test_cursor_fetchmany( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """ Cursor's fetchmany fetches the provided amount of rows, or arraysize by default. If not enough rows left, returns less or None if there are no rows. """ - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() await cursor.execute("sql") @@ -466,24 +411,19 @@ async def test_cursor_fetchmany( len(await cursor.fetchmany()) == 0 ), "fetchmany should return empty result set when no rows remain to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() await cursor.execute("sql") with raises(DataError): await cursor.fetchmany() async def test_cursor_fetchall( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchall fetches all rows remaining after last query.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() await cursor.execute("sql") @@ -500,28 +440,23 @@ async def test_cursor_fetchall( len(await cursor.fetchall()) == 0 ), "fetchmany should return empty result set when no rows remain to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() await cursor.execute("sql") with raises(DataError): await cursor.fetchall() async def test_cursor_multi_statement( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], ): """executemany with multiple parameter sets is not supported.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() + mock_insert_query() + mock_query() rc = await cursor.execute( "select * from t; insert into t values (1, 2); select * from t" @@ -570,14 +505,11 @@ async def test_cursor_multi_statement( async def test_cursor_set_statements( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, select_one_query_callback: Callable, set_query_url: str, cursor: Cursor, ): """cursor correctly parses and processes set statements.""" - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") assert len(cursor._set_parameters) == 0 @@ -639,8 +571,6 @@ async def test_cursor_set_statements( async def test_cursor_set_parameters_sent( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, set_query_url: str, query_url: str, query_with_params_callback: Callable, @@ -649,8 +579,6 @@ async def test_cursor_set_parameters_sent( set_params: Dict, ): """Cursor passes provided set parameters to engine.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - params = "" for p, v in set_params.items(): @@ -666,16 +594,11 @@ async def test_cursor_set_parameters_sent( async def test_cursor_skip_parse( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_url: str, - query_callback: Callable, + mock_query: Callable, cursor: Cursor, ): """Cursor doesn't process a query if skip_parsing is provided.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() with patch("firebolt.async_db.cursor.split_format_sql") as split_format_sql_mock: await cursor.execute("non-an-actual-sql") @@ -688,8 +611,6 @@ async def test_cursor_skip_parse( async def test_cursor_server_side_async_execute( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_id_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -712,7 +633,6 @@ async def test_cursor_server_side_async_execute( ), ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_id_callback, url=query_with_params_url ) @@ -730,8 +650,6 @@ async def test_cursor_server_side_async_execute( async def test_cursor_server_side_async_cancel( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_cancel_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -743,7 +661,6 @@ async def test_cursor_server_side_async_cancel( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_cancel_callback, url=query_with_params_url ) @@ -752,8 +669,6 @@ async def test_cursor_server_side_async_cancel( async def test_cursor_server_side_async_get_status_completed( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -765,7 +680,6 @@ async def test_cursor_server_side_async_get_status_completed( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_callback, url=query_with_params_url ) @@ -775,8 +689,6 @@ async def test_cursor_server_side_async_get_status_completed( async def test_cursor_server_side_async_get_status_not_yet_available( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_not_yet_availabe_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -788,7 +700,6 @@ async def test_cursor_server_side_async_get_status_not_yet_available( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_not_yet_availabe_callback, url=query_with_params_url, @@ -799,15 +710,12 @@ async def test_cursor_server_side_async_get_status_not_yet_available( async def test_cursor_server_side_async_get_status_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_error: Callable, server_side_async_id: Callable, query_with_params_url: str, cursor: Cursor, ): """ """ - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_error, url=query_with_params_url ) diff --git a/tests/unit/client/auth/test_auth.py b/tests/unit/client/auth/test_auth.py index 8827546d170..94203b0875a 100644 --- a/tests/unit/client/auth/test_auth.py +++ b/tests/unit/client/auth/test_auth.py @@ -12,7 +12,7 @@ def test_auth_refresh_on_expiration( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth refreshes the token on expiration.""" url = "https://host" @@ -29,19 +29,19 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert auth.expired # Refresh token - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token2, "Expired access token was not updated." + assert auth.token == access_token_2, "Expired access token was not updated." def test_auth_uses_same_token_if_valid( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth reuses the token until it's expired.""" url = "https://host" @@ -58,37 +58,37 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert not auth.expired # Refresh token - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token, "Should not update token until it expires." + assert auth.token == access_token, "Should not update token until it expires." -def test_auth_adds_header(test_token: str) -> None: +def test_auth_adds_header(access_token: str) -> None: """Auth adds required authentication headers to httpx.Request.""" auth = Auth(use_token_cache=False) - auth._token = test_token + auth._token = access_token auth._expires = 2**32 flow = auth.auth_flow(Request("get", "")) request = next(flow) assert "authorization" in request.headers, "missing authorization header" assert ( - request.headers["authorization"] == f"Bearer {test_token}" + request.headers["authorization"] == f"Bearer {access_token}" ), "missing authorization header" @mark.nofakefs def test_auth_token_storage( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token, + client_id: str, + client_secret: str, + access_token: str, ) -> None: # Mock auth flow def set_token(token: str) -> callable: @@ -104,26 +104,26 @@ def inner(self): with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=True) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - st = TokenSecureStorage(test_username, test_password) - assert st.get_cached_token() == test_token, "Invalid token value cached" + st = TokenSecureStorage(client_id, client_secret) + assert st.get_cached_token() == access_token, "Invalid token value cached" with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=False) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - st = TokenSecureStorage(test_username, test_password) + st = TokenSecureStorage(client_id, client_secret) assert ( st.get_cached_token() is None ), "Token cached even though caching is disabled" diff --git a/tests/unit/client/auth/test_auth_async.py b/tests/unit/client/auth/test_auth_async.py index 93990283b9e..abd0fd2cc36 100644 --- a/tests/unit/client/auth/test_auth_async.py +++ b/tests/unit/client/auth/test_auth_async.py @@ -12,7 +12,7 @@ async def test_auth_refresh_on_expiration( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth refreshes the token on expiration.""" url = "https://host" @@ -29,19 +29,19 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests(auth.async_auth_flow(Request("GET", url))) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert auth.expired # Refresh token - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) await async_execute_generator_requests(auth.async_auth_flow(Request("GET", url))) - assert auth.token == test_token2, "expired access token was not updated" + assert auth.token == access_token_2, "expired access token was not updated" async def test_auth_uses_same_token_if_valid( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth reuses the token until it's expired.""" url = "https://host" @@ -58,40 +58,40 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", "https://host")) ) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert not auth.expired - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", "https://host")) ) - assert auth.token == test_token, "shoud not update token until it expires" + assert auth.token == access_token, "shoud not update token until it expires" -async def test_auth_adds_header(test_token: str) -> None: +async def test_auth_adds_header(access_token: str) -> None: """Auth adds required authentication headers to httpx.Request.""" auth = Auth(use_token_cache=False) - auth._token = test_token + auth._token = access_token auth._expires = 2**32 flow = auth.async_auth_flow(Request("get", "")) request = await flow.__anext__() assert "authorization" in request.headers, "missing authorization header" assert ( - request.headers["authorization"] == f"Bearer {test_token}" + request.headers["authorization"] == f"Bearer {access_token}" ), "missing authorization header" @mark.nofakefs async def test_auth_token_storage( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token, + client_id: str, + client_secret: str, + access_token, ) -> None: # Mock auth flow def set_token(token: str) -> callable: @@ -107,30 +107,30 @@ def inner(self): with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=True) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", url)) ) - st = TokenSecureStorage(test_username, test_password) - assert st.get_cached_token() == test_token, "Invalid token value cached" + st = TokenSecureStorage(client_id, client_secret) + assert st.get_cached_token() == access_token, "Invalid token value cached" with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=False) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", url)) ) - st = TokenSecureStorage(test_username, test_password) + st = TokenSecureStorage(client_id, client_secret) assert ( st.get_cached_token() is None ), "Token cached even though caching is disabled" diff --git a/tests/unit/client/auth/test_request_auth.py b/tests/unit/client/auth/test_request_auth.py index d0e4ef87e54..bcb028d500a 100644 --- a/tests/unit/client/auth/test_request_auth.py +++ b/tests/unit/client/auth/test_request_auth.py @@ -2,56 +2,36 @@ import pytest from httpx import StreamError, codes -from pytest import mark from pytest_httpx import HTTPXMock from pytest_mock import MockerFixture -from firebolt.client.auth import Auth, ServiceAccount, UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.utils.exception import AuthenticationError from tests.unit.util import execute_generator_requests def test_auth_service_account( - httpx_mock: HTTPXMock, - mocker: MockerFixture, - check_service_credentials_callback: typing.Callable, - mock_service_id: str, - mock_service_secret: str, - test_token: str, -): - """Auth can retrieve token and expiration values.""" - httpx_mock.add_callback(check_service_credentials_callback) - - mocker.patch("firebolt.client.auth.request_auth_base.time", return_value=0) - auth = ServiceAccount(mock_service_id, mock_service_secret) - execute_generator_requests(auth.get_new_token_generator()) - assert auth.token == test_token, "invalid access token" - assert auth._expires == 2**32, "invalid expiration value" - - -def test_auth_username_password( httpx_mock: HTTPXMock, mocker: MockerFixture, check_credentials_callback: typing.Callable, - test_username, - test_password, - test_token, + client_id: str, + client_secret: str, + access_token: str, ): """Auth can retrieve token and expiration values.""" httpx_mock.add_callback(check_credentials_callback) mocker.patch("firebolt.client.auth.request_auth_base.time", return_value=0) - auth = UsernamePassword(test_username, test_password) + auth = ClientCredentials(client_id, client_secret) execute_generator_requests(auth.get_new_token_generator()) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert auth._expires == 2**32, "invalid expiration value" -@mark.parametrize("auth_class", [UsernamePassword, ServiceAccount]) -def test_auth_error_handling(httpx_mock: HTTPXMock, auth_class: Auth): +def test_auth_error_handling(httpx_mock: HTTPXMock, client_id: str, client_secret: str): """Auth handles various errors properly.""" for api_endpoint in ("https://host", "host"): - auth = auth_class("user", "password", use_token_cache=False) + auth = ClientCredentials(client_id, client_secret, use_token_cache=False) # Internal httpx error def http_error(*args, **kwargs): diff --git a/tests/unit/client/auth/test_token.py b/tests/unit/client/auth/test_token.py deleted file mode 100644 index f7cc2df07da..00000000000 --- a/tests/unit/client/auth/test_token.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Callable - -from httpx import Request, Response, codes -from pytest import raises -from pytest_httpx import HTTPXMock - -from firebolt.client.auth import Token -from firebolt.utils.exception import AuthorizationError -from tests.unit.util import execute_generator_requests - - -def test_token_happy_path( - httpx_mock: HTTPXMock, test_token: str, check_token_callback: Callable -) -> None: - """Validate that provided token is added to request.""" - httpx_mock.add_callback(check_token_callback) - - auth = Token(test_token) - execute_generator_requests(auth.auth_flow(Request("GET", "https://host"))) - - -def test_token_invalid(httpx_mock: HTTPXMock) -> None: - """Authorization error raised when token is invalid.""" - - def authorization_error(*args, **kwargs) -> Response: - return Response(status_code=codes.UNAUTHORIZED) - - httpx_mock.add_callback(authorization_error) - - auth = Token("token") - with raises(AuthorizationError): - execute_generator_requests(auth.auth_flow(Request("GET", "https://host"))) - - -def test_token_expired() -> None: - """Authorization error is raised when token expires.""" - auth = Token("token") - auth._expires = 0 - with raises(AuthorizationError): - execute_generator_requests(auth.auth_flow(Request("GET", "https://host"))) diff --git a/tests/unit/client/conftest.py b/tests/unit/client/conftest.py deleted file mode 100644 index 8fe56a0ec33..00000000000 --- a/tests/unit/client/conftest.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import typing - -import httpx -from httpx import Response -from pytest import fixture - -from firebolt.common.settings import Settings - - -@fixture -def test_token(access_token: str) -> str: - return access_token - - -@fixture -def test_token2() -> str: - return "test_token2" - - -@fixture -def test_username(settings: Settings) -> str: - return settings.user - - -@fixture -def test_password(settings: Settings) -> str: - return settings.password - - -@fixture -def mock_service_id() -> str: - return "mock_service_id" - - -@fixture -def mock_service_secret() -> str: - return "my_secret" - - -@fixture -def check_credentials_callback( - test_username: str, test_password: str, test_token: str -) -> typing.Callable: - def check_credentials( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request, "empty request" - body = json.loads(request.read()) - assert "username" in body, "Missing username" - assert body["username"] == test_username, "Invalid username" - assert "password" in body, "Missing password" - assert body["password"] == test_password, "Invalid password" - - return Response( - status_code=httpx.codes.OK, - json={"expires_in": 2**32, "access_token": test_token}, - ) - - return check_credentials - - -@fixture -def check_service_credentials_callback( - mock_service_id: str, mock_service_secret: str, test_token: str -) -> typing.Callable: - def check_credentials( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request, "empty request" - body = request.read().decode("utf-8") - assert "client_id" in body, "Missing id" - assert f"client_id={mock_service_id}" in body, "Invalid id" - assert "client_secret" in body, "Missing secret" - assert f"client_secret={mock_service_secret}" in body, "Invalid secret" - assert "grant_type" in body, "Missing grant_type" - assert "grant_type=client_credentials" in body, "Invalid grant_type" - - return Response( - status_code=httpx.codes.OK, - json={"expires_in": 2**32, "access_token": test_token}, - ) - - return check_credentials diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index aba06631178..0db63da2dd2 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -6,65 +6,65 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.client import DEFAULT_API_URL, Client -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client import Client +from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx from firebolt.common import Settings from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import AUTH_URL +from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema def test_client_retry( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, + account_name: str, + access_token: str, ): """ Client retries with new auth token if first attempt fails with unauthorized error. """ - client = Client(auth=UsernamePassword(test_username, test_password)) + with Client(account_name=account_name, auth=auth) as client: - # auth get token - httpx_mock.add_response( - status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, - ) + # auth get token + httpx_mock.add_response( + status_code=codes.OK, + json={"expires_in": 2**30, "access_token": access_token}, + ) - # client request failed authorization - httpx_mock.add_response( - status_code=codes.UNAUTHORIZED, - ) + # client request failed authorization + httpx_mock.add_response( + status_code=codes.UNAUTHORIZED, + ) - # auth get another token - httpx_mock.add_response( - status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, - ) + # auth get another token + httpx_mock.add_response( + status_code=codes.OK, + json={"expires_in": 2**30, "access_token": access_token}, + ) - # client request success this time - httpx_mock.add_response( - status_code=codes.OK, - ) + # client request success this time + httpx_mock.add_response( + status_code=codes.OK, + ) - assert ( - client.get("https://url").status_code == codes.OK - ), "request failed with firebolt client" + assert ( + client.get("https://url").status_code == codes.OK + ), "request failed with firebolt client" def test_client_different_auths( httpx_mock: HTTPXMock, check_credentials_callback: Callable, check_token_callback: Callable, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, + account_name: str, + auth_server: str, + server: str, ): """ Client properly handles such auth types: - - tuple(username, password) - Auth - None All other types should raise TypeError. @@ -72,20 +72,15 @@ def test_client_different_auths( httpx_mock.add_callback( check_credentials_callback, - url=f"https://{DEFAULT_API_URL}{AUTH_URL}", + url=f"https://{auth_server}{AUTH_SERVICE_ACCOUNT_URL}", ) httpx_mock.add_callback(check_token_callback, url="https://url") - Client(auth=UsernamePassword(test_username, test_password)).get("https://url") - Client(auth=Token(test_token)).get("https://url") - - # client accepts None auth, but authorization fails - with raises(AssertionError) as excinfo: - Client(auth=None).get("https://url") + Client(account_name=account_name, auth=auth, api_endpoint=server).get("https://url") with raises(TypeError) as excinfo: - Client(auth=lambda r: r).get("https://url") + Client(account_name=account_name, auth=lambda r: r).get("https://url") assert str(excinfo.value).startswith( 'Invalid "auth" argument' @@ -94,8 +89,8 @@ def test_client_different_auths( def test_client_account_id( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, + auth: Auth, + account_name: str, account_id: str, account_id_url: Pattern, account_id_callback: Callable, @@ -107,7 +102,8 @@ def test_client_account_id( httpx_mock.add_callback(auth_callback, url=auth_url) with Client( - auth=UsernamePassword(test_username, test_password), + account_name=account_name, + auth=auth, base_url=fix_url_schema(settings.server), api_endpoint=settings.server, ) as c: @@ -118,19 +114,21 @@ def test_client_account_id( def test_refresh_with_hooks( fs: FakeFilesystem, httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token: str, + account_name: str, + client_id: str, + client_secret: str, + access_token: str, ) -> None: """ When hooks are used, the invalid token, fetched from cache, is refreshed """ - tss = TokenSecureStorage(test_username, test_password) - tss.cache_token(test_token, 2**32) + tss = TokenSecureStorage(client_id, client_secret) + tss.cache_token(access_token, 2**32) client = Client( - auth=UsernamePassword(test_username, test_password), + account_name=account_name, + auth=ClientCredentials(client_id, client_secret), event_hooks={ "response": [raise_on_4xx_5xx], }, @@ -144,7 +142,7 @@ def test_refresh_with_hooks( # auth get another token httpx_mock.add_response( status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, + json={"expires_in": 2**30, "access_token": access_token}, ) # client request success this time diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index 81d513727fc..acad0c4a286 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -5,31 +5,28 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client import AsyncClient +from firebolt.client.auth import Auth from firebolt.common import Settings -from firebolt.utils.urls import AUTH_URL +from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema async def test_client_retry( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, + account_name: str, + access_token: str, ): """ Client retries with new auth token if first attempt fails with unauthorized error. """ - async with AsyncClient( - auth=UsernamePassword(test_username, test_password) - ) as client: - + async with AsyncClient(account_name=account_name, auth=auth) as client: # auth get token httpx_mock.add_response( status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, + json={"expires_in": 2**30, "access_token": access_token}, ) # client request failed authorization @@ -40,7 +37,7 @@ async def test_client_retry( # auth get another token httpx_mock.add_response( status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, + json={"expires_in": 2**30, "access_token": access_token}, ) # client request success this time @@ -57,9 +54,10 @@ async def test_client_different_auths( httpx_mock: HTTPXMock, check_credentials_callback: Callable, check_token_callback: Callable, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, + account_name: str, + auth_server: str, + server: str, ): """ Client properly handles such auth types: @@ -71,25 +69,20 @@ async def test_client_different_auths( httpx_mock.add_callback( check_credentials_callback, - url=f"https://{DEFAULT_API_URL}{AUTH_URL}", + url=f"https://{auth_server}{AUTH_SERVICE_ACCOUNT_URL}", ) httpx_mock.add_callback(check_token_callback, url="https://url") async with AsyncClient( - auth=UsernamePassword(test_username, test_password) + account_name=account_name, auth=auth, api_endpoint=server ) as client: await client.get("https://url") - async with AsyncClient(auth=Token(test_token)) as client: - await client.get("https://url") - - # client accepts None auth, but authorization fails - with raises(AssertionError) as excinfo: - async with AsyncClient(auth=None) as client: - await client.get("https://url") with raises(TypeError) as excinfo: - async with AsyncClient(auth=lambda r: r): + async with AsyncClient( + account_name=account_name, auth=lambda r: r, api_endpoint=server + ): await client.get("https://url") assert str(excinfo.value).startswith( @@ -99,8 +92,8 @@ async def test_client_different_auths( async def test_client_account_id( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, + auth: Auth, + account_name: str, account_id: str, account_id_url: Pattern, account_id_callback: Callable, @@ -112,7 +105,8 @@ async def test_client_account_id( httpx_mock.add_callback(auth_callback, url=auth_url) async with AsyncClient( - auth=UsernamePassword(test_username, test_password), + account_name=account_name, + auth=auth, base_url=fix_url_schema(settings.server), api_endpoint=settings.server, ) as c: diff --git a/tests/unit/common/test_settings.py b/tests/unit/common/test_settings.py index cd4c23c4c14..084046ff2bf 100644 --- a/tests/unit/common/test_settings.py +++ b/tests/unit/common/test_settings.py @@ -1,22 +1,12 @@ import os -from typing import Tuple from unittest.mock import Mock, patch -from pytest import mark, raises - from firebolt.client.auth import Auth from firebolt.common.settings import Settings -@mark.parametrize( - "fields", - ( - ("user", "password", "account_name", "server", "default_region"), - ("access_token", "account_name", "server", "default_region"), - ("auth", "account_name", "server", "default_region"), - ), -) -def test_settings_happy_path(fields: Tuple[str]) -> None: +def test_settings_happy_path() -> None: + fields = ("auth", "account_name", "server", "default_region") kwargs = {f: (f if f != "auth" else Auth()) for f in fields} s = Settings(**kwargs) @@ -27,30 +17,13 @@ def test_settings_happy_path(fields: Tuple[str]) -> None: ), f"Invalid settings value {f}" -creds_fields = ("access_token", "user", "password") -other_fields = ("server", "default_region") - - -@mark.parametrize( - "kwargs", - ( - {f: f for f in other_fields}, - {f: f for f in creds_fields + other_fields}, - {"auth": Auth(), "access_token": "123", **{f: f for f in other_fields}}, - ), -) -def test_settings_auth_credentials(kwargs) -> None: - with raises(ValueError) as exc_info: - Settings(**kwargs) - - @patch("firebolt.common.settings.logger") def test_no_deprecation_warning_with_env(logger_mock: Mock): with patch.dict( os.environ, { - "FIREBOLT_USER": "user", - "FIREBOLT_PASSWORD": "password", + "FIREBOLT_CLIENT_ID": "client_id", + "FIREBOLT_CLIENT_SECRET": "client_secret", "FIREBOLT_SERVER": "dummy.firebolt.io", }, clear=True, @@ -59,5 +32,7 @@ def test_no_deprecation_warning_with_env(logger_mock: Mock): logger_mock.warning.assert_not_called() assert s.server == "dummy.firebolt.io" assert s.auth is not None, "Settings.auth wasn't populated from env variables" - assert s.auth.username == "user", "Invalid username in Settings.auth" - assert s.auth.password == "password", "Invalid password in Settings.auth" + assert s.auth.client_id == "client_id", "Invalid username in Settings.auth" + assert ( + s.auth.client_secret == "client_secret" + ), "Invalid password in Settings.auth" diff --git a/tests/unit/common/test_util.py b/tests/unit/common/test_util.py deleted file mode 100644 index 1b02b7b3dc9..00000000000 --- a/tests/unit/common/test_util.py +++ /dev/null @@ -1,68 +0,0 @@ -from asyncio import run -from threading import Thread - -from pytest import raises - -from firebolt.utils.util import async_to_sync - - -def test_async_to_sync_happy_path(): - """async_to_sync properly converts coroutine to sync function.""" - - class JobMarker(Exception): - pass - - async def task(): - raise JobMarker() - - for i in range(3): - with raises(JobMarker): - async_to_sync(task)() - - -def test_async_to_sync_thread(): - """async_to_sync works correctly in threads.""" - - marks = [False] * 3 - - async def task(id: int): - marks[id] = True - - ts = [Thread(target=async_to_sync(task), args=[i]) for i in range(3)] - [t.start() for t in ts] - [t.join() for t in ts] - assert all(marks) - - -def test_async_to_sync_after_run(): - """async_to_sync runs correctly after asyncio.run.""" - - class JobMarker(Exception): - pass - - async def task(): - raise JobMarker() - - with raises(JobMarker): - run(task()) - - # Here local event loop is closed by run - - with raises(JobMarker): - async_to_sync(task)() - - -async def test_nested_loops() -> None: - """async_to_sync works correctly inside a running loop.""" - - class JobMarker(Exception): - pass - - async def task(): - raise JobMarker() - - with raises(JobMarker): - await task() - - with raises(JobMarker): - async_to_sync(task)() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1b9e17fde8b..9910471dc29 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,4 +1,3 @@ -from json import loads from re import Pattern, compile from typing import Callable, List @@ -7,7 +6,7 @@ from pyfakefs.fake_filesystem_unittest import Patcher from pytest import fixture -from firebolt.client.auth import Auth, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common.settings import Settings from firebolt.model.provider import Provider from firebolt.model.region import Region, RegionKey @@ -28,9 +27,7 @@ ACCOUNT_BY_NAME_URL, ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_ENGINE_URL, - ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, - ACCOUNT_URL, - AUTH_URL, + AUTH_SERVICE_ACCOUNT_URL, DATABASES_URL, ENGINES_URL, ) @@ -52,13 +49,13 @@ def global_fake_fs(request) -> None: @fixture -def username() -> str: - return "email@domain.com" +def client_id() -> str: + return "client_id" @fixture -def password() -> str: - return "*****" +def client_secret() -> str: + return "client_secret" @fixture @@ -66,16 +63,31 @@ def server() -> str: return "api-dev.mock.firebolt.io" +@fixture +def auth_server() -> str: + return "id.mock.firebolt.io" + + @fixture def account_id() -> str: return "mock_account_id" +@fixture +def account_name() -> str: + return "mock_account_name" + + @fixture def access_token() -> str: return "mock_access_token" +@fixture +def access_token_2() -> str: + return "mock_access_token_2" + + @fixture def provider() -> Provider: return Provider( @@ -117,21 +129,20 @@ def mock_regions(region_1, region_2) -> List[Region]: @fixture -def settings(server: str, region_1: str, username: str, password: str) -> Settings: +def auth(client_id: str, client_secret: str) -> Auth: + return ClientCredentials(client_id, client_secret) + + +@fixture +def settings(server: str, region_1: str, auth: Auth, account_name: str) -> Settings: return Settings( server=server, - user=username, - password=password, + auth=auth, default_region=region_1.name, - account_name=None, + account_name=account_name, ) -@fixture -def auth(username: str, password: str) -> Auth: - return UsernamePassword(username, password) - - @fixture def auth_callback(auth_url: str) -> Callable: def do_mock( @@ -148,8 +159,8 @@ def do_mock( @fixture -def auth_url(settings: Settings) -> str: - return f"https://{settings.server}{AUTH_URL}" +def auth_url(auth_server: str) -> str: + return f"https://{auth_server}{AUTH_SERVICE_ACCOUNT_URL}" @fixture @@ -163,12 +174,12 @@ def db_description() -> str: @fixture -def account_id_url(settings: Settings) -> Pattern: - base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}?account_name=" - default_base = f"https://{settings.server}{ACCOUNT_URL}" +def account_id_url(settings: Settings, account_name: str) -> Pattern: + account_name_re = r"[^\\\\]*" + base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}" base = base.replace("/", "\\/").replace("?", "\\?") - default_base = default_base.replace("/", "\\/").replace("?", "\\?") - return compile(f"(?:{base}.*|{default_base})") + base = base.format(account_name=account_name_re) + return compile(base) @fixture @@ -180,14 +191,9 @@ def do_mock( request: Request, **kwargs, ) -> Response: - if "account_name" not in request.url.params: - return Response( - status_code=httpx.codes.OK, json={"account": {"id": account_id}} - ) - # In this case, an account_name *should* be specified. - if request.url.params["account_name"] != settings.account_name: - raise AccountNotFoundError(request.url.params["account_name"]) - return Response(status_code=httpx.codes.OK, json={"account_id": account_id}) + if request.url.path.split("/")[-2] != settings.account_name: + raise AccountNotFoundError(request.url.path.split("/")[-2]) + return Response(status_code=httpx.codes.OK, json={"id": account_id}) return do_mock @@ -216,48 +222,6 @@ def get_engine_name_by_id_url( ) -@fixture -def get_engine_url_by_id_url( - settings: Settings, account_id: str, engine_id: str -) -> str: - return f"https://{settings.server}" + ACCOUNT_ENGINE_URL.format( - account_id=account_id, engine_id=engine_id - ) - - -@fixture -def get_engine_url_by_id_callback( - get_engine_url_by_id_url: str, engine_id: str, settings: Settings -) -> Callable: - def do_mock( - request: Request = None, - **kwargs, - ) -> Response: - assert request.url == get_engine_url_by_id_url - return Response( - status_code=httpx.codes.OK, - json={ - "engine": { - "name": "name", - "compute_region_id": { - "provider_id": "provider", - "region_id": "region", - }, - "settings": { - "preset": "", - "auto_stop_delay_duration": "1s", - "minimum_logging_level": "", - "is_read_only": False, - "warm_up": "", - }, - "endpoint": f"https://{settings.server}", - } - }, - ) - - return do_mock - - @fixture def get_engines_url(settings: Settings) -> str: return f"https://{settings.server}{ENGINES_URL}" @@ -301,14 +265,6 @@ def do_mock( return do_mock -@fixture -def engine_by_db_url(settings: Settings, account_id: str) -> str: - return ( - f"https://{settings.server}" - f"{ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id)}" - ) - - @fixture def db_api_exceptions(): exceptions = { @@ -343,17 +299,21 @@ def check_token(request: Request = None, **kwargs) -> Response: @fixture -def check_credentials_callback(settings: Settings, access_token: str) -> Callable: +def check_credentials_callback( + client_id: str, client_secret: str, access_token: str +) -> Callable: def check_credentials( - request: Request = None, + request: httpx.Request = None, **kwargs, ) -> Response: assert request, "empty request" - body = loads(request.read()) - assert "username" in body, "Missing username" - assert body["username"] == settings.user, "Invalid username" - assert "password" in body, "Missing password" - assert body["password"] == settings.password, "Invalid password" + body = request.read().decode("utf-8") + assert "client_id" in body, "Missing id" + assert f"client_id={client_id}" in body, "Invalid id" + assert "client_secret" in body, "Missing secret" + assert f"client_secret={client_secret}" in body, "Invalid secret" + assert "grant_type" in body, "Missing grant_type" + assert "grant_type=client_credentials" in body, "Invalid grant_type" return Response( status_code=httpx.codes.OK, diff --git a/tests/unit/db/conftest.py b/tests/unit/db/conftest.py index 1c06ed3a253..85b3a1d6082 100644 --- a/tests/unit/db/conftest.py +++ b/tests/unit/db/conftest.py @@ -1,17 +1,27 @@ +from typing import Callable + from pytest import fixture -from firebolt.common.settings import Settings +from firebolt.client.auth import Auth from firebolt.db import Connection, Cursor, connect @fixture -def connection(settings: Settings, db_name: str) -> Connection: +def connection( + server: str, + db_name: str, + auth: Auth, + engine_name: str, + account_name: str, + mock_connection_flow: Callable, +) -> Connection: + mock_connection_flow() with connect( - engine_url=settings.server, + engine_name=engine_name, database=db_name, - username="u", - password="p", - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: yield connection diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 571c8d7f06d..89ef9e741cb 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,25 +1,23 @@ import gc import warnings -from re import Pattern from typing import Callable, List +from unittest.mock import patch -from httpx import codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( - AccountNotFoundError, ConfigurationError, ConnectionClosedError, - FireboltEngineError, + EngineNotRunningError, + InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import ACCOUNT_ENGINE_ID_BY_NAME_URL def test_closed_connection(connection: Connection) -> None: @@ -52,198 +50,132 @@ def test_cursors_closed_on_close(connection: Connection) -> None: def test_cursor_initialized( settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_query: Callable, + connection: Connection, python_query_data: List[List[ColType]], ) -> None: """Connection initialized its cursors properly.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() - for url in (settings.server, f"https://{settings.server}"): - with connect( - engine_url=url, - database=db_name, - username="u", - password="p", - api_endpoint=settings.server, - ) as connection: + cursor = connection.cursor() + assert cursor.connection == connection, "Invalid cursor connection attribute" + assert cursor._client == connection._client, "Invalid cursor _client attribute" - cursor = connection.cursor() - assert ( - cursor.connection == connection - ), "Invalid cursor connection attribute" - assert ( - cursor._client == connection._client - ), "Invalid cursor _client attribute" + assert cursor.execute("select*") == len(python_query_data) - assert cursor.execute("select*") == len(python_query_data) - - cursor.close() - assert ( - cursor not in connection._cursors - ), "Cursor wasn't removed from connection after close" + cursor.close() + assert ( + cursor not in connection._cursors + ), "Cursor wasn't removed from connection after close" def test_connect_empty_parameters(): with raises(ConfigurationError): - with connect(engine_url="engine_url"): - pass - - -def test_connect_access_token( - settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - check_token_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], - access_token: str, -): - httpx_mock.add_callback(check_token_callback, url=query_url) - with ( - connect( - engine_url=settings.server, - database=db_name, - access_token=access_token, - api_endpoint=settings.server, - ) - ) as connection: - cursor = connection.cursor() - assert cursor.execute("select*") == -1 - - with raises(ConfigurationError): - with connect(engine_url="engine_url", database="database"): - pass - - with raises(ConfigurationError): - with connect( - engine_url="engine_url", - database="database", - username="username", - password="password", - access_token="access_token", - ): + with connect(engine_name="engine_name"): pass def test_connect_engine_name( - settings: Settings, db_name: str, + account_name: str, + engine_name: str, + auth: Auth, + server: str, + python_query_data: List[List[ColType]], + mock_query: Callable, httpx_mock: HTTPXMock, - auth_callback: Callable, + system_engine_query_url: str, + get_engine_url_not_running_callback: Callable, + get_engine_url_invalid_db_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + get_engine_url_callback: Callable, + account_id_url: str, account_id_callback: Callable, - engine_id: str, - get_engine_url_by_id_url: str, - get_engine_url_by_id_callback: Callable, - python_query_data: List[List[ColType]], - account_id: str, ): """connect properly handles engine_name""" - - with raises(ConfigurationError): - connect( - engine_url="engine_url", - engine_name="engine_name", - database="db", - username="username", - password="password", - ) - - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(get_engine_url_by_id_callback, url=get_engine_url_by_id_url) - - engine_name = settings.server.split(".")[0] - - # Mock engine id lookup error - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.NOT_FOUND, - ) - - with raises(FireboltEngineError): - connect( - database="db", - username="username", - password="password", - engine_name=engine_name, - account_name=settings.account_name, - api_endpoint=settings.server, - ) - # Mock engine id lookup by name - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.OK, - json={"engine_id": {"engine_id": engine_id}}, - ) + mock_query() + + for callback, err_cls in ( + (get_engine_url_invalid_db_callback, InterfaceError), + (get_engine_url_not_running_callback, EngineNotRunningError), + ): + httpx_mock.add_callback(callback, url=system_engine_query_url) + with raises(err_cls): + c = connect( + database=db_name, + auth=auth, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + ) + print(type(c)) + with c: + pass + + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) with connect( engine_name=engine_name, database=db_name, - username="u", - password="p", - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) -def test_connect_default_engine( - settings: Settings, +def test_connect_database( db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - database_id: str, - engine_by_db_url: str, + server: str, + auth: Auth, + account_name: str, python_query_data: List[List[ColType]], - account_id: str, + httpx_mock: HTTPXMock, + query_callback: str, + check_credentials_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + get_system_engine_url: str, + get_system_engine_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(query_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) - engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}" - - httpx_mock.add_response( - url=engine_by_db_url, - status_code=codes.OK, - json={ - "engine_url": settings.server, - }, - ) + with connect( + database=None, + auth=auth, + account_name=account_name, + api_endpoint=server, + ) as connection: + connection.cursor().execute("select*") + + httpx_mock.reset(True) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(query_callback, url=system_engine_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + with connect( database=db_name, - username="u", - password="p", - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) -def test_connection_unclosed_warnings(): - c = Connection("", "", None, "") +def test_connection_unclosed_warnings(auth: Auth): + c = Connection("", "", auth, "", None) with warns(UserWarning) as winfo: del c gc.collect() @@ -253,8 +185,8 @@ def test_connection_unclosed_warnings(): ), "Invalid unclosed connection warning" -def test_connection_no_warnings(): - c = Connection("", "", None, "") +def test_connection_no_warnings(auth: Auth): + c = Connection("", "", auth, "", None) c.close() with warnings.catch_warnings(): warnings.simplefilter("error") @@ -273,110 +205,109 @@ def test_connection_commit(connection: Connection): @mark.nofakefs def test_connection_token_caching( - settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], + server: str, access_token: str, + client_id: str, + client_secret: str, + engine_name: str, + account_name: str, + python_query_data: List[List[ColType]], + mock_connection_flow: Callable, + mock_query: Callable, ) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_connection_flow() + mock_query() with Patcher(): with connect( database=db_name, - username=settings.user, - password=settings.password, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - use_token_cache=True, + auth=ClientCredentials(client_id, client_secret, use_token_cache=True), + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): with connect( database=db_name, - username=settings.user, - password=settings.password, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - use_token_cache=False, + auth=ClientCredentials(client_id, client_secret, use_token_cache=False), + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" -def test_connect_with_auth( - httpx_mock: HTTPXMock, - settings: Settings, +def test_connect_with_user_agent( + engine_name: str, + account_name: str, + server: str, db_name: str, - check_credentials_callback: Callable, - auth_url: str, + auth: Auth, + access_token: str, + httpx_mock: HTTPXMock, query_callback: Callable, query_url: str, - access_token: str, + mock_connection_flow: Callable, ) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - - for auth in ( - UsernamePassword( - settings.user, - settings.password, - use_token_cache=False, - ), - Token(access_token), - ): + with patch("firebolt.db.connection.get_user_agent_header") as ut: + ut.return_value = "MyConnector/1.0 DriverA/1.1" + mock_connection_flow() + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": "MyConnector/1.0 DriverA/1.1"}, + ) + with connect( auth=auth, database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + additional_parameters={ + "user_clients": [("MyConnector", "1.0")], + "user_drivers": [("DriverA", "1.1")], + }, ) as connection: connection.cursor().execute("select*") + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) -def test_connect_account_name( - httpx_mock: HTTPXMock, - auth: Auth, - settings: Settings, +def test_connect_no_user_agent( + engine_name: str, + account_name: str, + server: str, db_name: str, - auth_url: str, - check_credentials_callback: Callable, - account_id_url: Pattern, - account_id_callback: Callable, -): - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) + auth: Auth, + access_token: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, +) -> None: + with patch("firebolt.db.connection.get_user_agent_header") as ut: + ut.return_value = "Python/3.0" + mock_connection_flow() + httpx_mock.add_callback( + query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} + ) - with raises(AccountNotFoundError): with connect( auth=auth, database=db_name, - engine_url=settings.server, - account_name="invalid", - api_endpoint=settings.server, - ): - pass - - with connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ): - pass + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + ) as connection: + connection.cursor().execute("select*") + ut.assert_called_with([], []) diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 8b8e18be77c..04f9082012b 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -18,15 +18,12 @@ def test_cursor_state( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, + mock_query: Callable, query_url: str, cursor: Cursor, ): """Cursor state changes depending on the operations performed with it.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert cursor._state == CursorState.NONE @@ -89,10 +86,7 @@ def test_closed_cursor(cursor: Cursor): def test_cursor_no_query( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_query: Callable, cursor: Cursor, ): """Some of cursor methods are unavailable until a query is run.""" @@ -103,8 +97,7 @@ def test_cursor_no_query( "nextset", ) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() for method in methods: with raises(QueryNotRunError): @@ -134,12 +127,8 @@ def test_cursor_no_query( def test_cursor_execute( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], @@ -157,8 +146,7 @@ def test_cursor_execute( ), ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert query() == len( python_query_data ), f"Invalid row count returned for {message}." @@ -181,11 +169,8 @@ def test_cursor_execute( cursor.fetchone() is None ), f"Non-empty fetchone after all data received for {message}." - httpx_mock.reset(True) - # Query with empty output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() assert query() == -1, f"Invalid row count for insert using {message}." assert ( cursor.rowcount == -1 @@ -197,8 +182,6 @@ def test_cursor_execute( def test_cursor_execute_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_url: str, cursor: Cursor, ): @@ -213,8 +196,6 @@ def test_cursor_execute_error( "server-side synchronous executemany()", ), ): - httpx_mock.add_callback(auth_callback, url=auth_url) - # Internal httpx error def http_error(*args, **kwargs): raise StreamError("httpx error") @@ -254,17 +235,12 @@ def http_error(*args, **kwargs): def test_cursor_fetchone( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchone fetches single row in correct order; if no rows returns None.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() cursor.execute("sql") @@ -279,27 +255,22 @@ def test_cursor_fetchone( cursor.fetchone() is None ), "fetchone should return None when no rows left to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() cursor.execute("sql") with raises(DataError): cursor.fetchone() def test_cursor_fetchmany( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """ Cursor's fetchmany fetches the provided amount of rows, or arraysize by default. If not enough rows left, returns less or None if there are no rows. """ - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() cursor.execute("sql") @@ -340,24 +311,19 @@ def test_cursor_fetchmany( len(cursor.fetchmany()) == 0 ), "fetchmany should return empty result set when no rows left to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() cursor.execute("sql") with raises(DataError): cursor.fetchmany() def test_cursor_fetchall( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchall fetches all rows that left after last query.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() cursor.execute("sql") @@ -374,28 +340,23 @@ def test_cursor_fetchall( len(cursor.fetchall()) == 0 ), "fetchmany should return empty result set when no rows left to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() cursor.execute("sql") with raises(DataError): cursor.fetchall() def test_cursor_multi_statement( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], ): """executemany with multiple parameter sets is not supported.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() + mock_insert_query() + mock_query() rc = cursor.execute("select * from t; insert into t values (1, 2); select * from t") assert rc == len(python_query_data), "Invalid row count returned" @@ -432,17 +393,11 @@ def test_cursor_multi_statement( def test_cursor_set_statements( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, select_one_query_callback: Callable, set_query_url: str, cursor: Cursor, - python_query_description: List[Column], - python_query_data: List[List[ColType]], ): """cursor correctly parses and processes set statements.""" - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") assert len(cursor._set_parameters) == 0 @@ -504,8 +459,6 @@ def test_cursor_set_statements( def test_cursor_set_parameters_sent( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, set_query_url: str, query_url: str, query_with_params_callback: Callable, @@ -514,8 +467,6 @@ def test_cursor_set_parameters_sent( set_params: Dict, ): """Cursor passes provided set parameters to engine.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - params = "" for p, v in set_params.items(): @@ -531,16 +482,11 @@ def test_cursor_set_parameters_sent( def test_cursor_skip_parse( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_url: str, - query_callback: Callable, + mock_query: Callable, cursor: Cursor, ): """Cursor doesn't process a query if skip_parsing is provided.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() with patch("firebolt.db.cursor.split_format_sql") as split_format_sql_mock: cursor.execute("non-an-actual-sql") @@ -553,8 +499,6 @@ def test_cursor_skip_parse( def test_cursor_server_side_async_execute( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_id_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -576,7 +520,6 @@ def test_cursor_server_side_async_execute( ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_id_callback, url=query_with_params_url ) @@ -594,8 +537,6 @@ def test_cursor_server_side_async_execute( def test_cursor_server_side_async_cancel( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_cancel_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -607,7 +548,6 @@ def test_cursor_server_side_async_cancel( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_cancel_callback, url=query_with_params_url ) @@ -616,8 +556,6 @@ def test_cursor_server_side_async_cancel( def test_cursor_server_side_async_get_status_completed( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -629,7 +567,6 @@ def test_cursor_server_side_async_get_status_completed( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_callback, url=query_with_params_url ) @@ -639,8 +576,6 @@ def test_cursor_server_side_async_get_status_completed( def test_cursor_server_side_async_get_status_not_yet_available( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_not_yet_availabe_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -652,7 +587,6 @@ def test_cursor_server_side_async_get_status_not_yet_available( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_not_yet_availabe_callback, url=query_with_params_url, @@ -663,15 +597,12 @@ def test_cursor_server_side_async_get_status_not_yet_available( def test_cursor_server_side_async_get_status_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_error: Callable, server_side_async_id: Callable, query_with_params_url: str, cursor: Cursor, ): """ """ - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_error, url=query_with_params_url ) diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index b9dccba3cf8..c5a740ac0d9 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -5,10 +5,12 @@ from httpx import URL, Request, Response, codes from pytest import fixture +from pytest_httpx import HTTPXMock from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType, Column from firebolt.common.settings import Settings from firebolt.db import ARRAY, DECIMAL +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME QUERY_ROW_COUNT: int = 10 @@ -332,8 +334,8 @@ def set_params() -> Dict: @fixture def query_url(settings: Settings, db_name: str) -> str: return URL( - f"https://{settings.server}/?database={db_name}" - f"&output_format={JSON_OUTPUT_FORMAT}" + f"https://{settings.server}/", + params={"output_format": JSON_OUTPUT_FORMAT, "database": db_name}, ) @@ -346,3 +348,170 @@ def set_query_url(settings: Settings, db_name: str) -> str: def query_with_params_url(query_url: str, set_params: str) -> str: params_encoded = "&".join([f"{k}={encode_param(v)}" for k, v in set_params.items()]) query_url = f"{query_url}&{params_encoded}" + + +def _get_engine_url_callback(server: str, db_name: str, status="Running") -> Callable: + def do_query(request: Request, **kwargs) -> Response: + set_parameters = request.url.params + assert ( + len(set_parameters) == 3 + and "output_format" in set_parameters + and "database" in set_parameters + and "account_id" in set_parameters + ) + data = [[server, db_name, status]] + query_response = { + "meta": [{"name": "name", "type": "Text"} for _ in range(len(data[0]))], + "data": data, + "rows": len(data), + # Real example of statistics field value, not used by our code + "statistics": { + "elapsed": 0.002983335, + "time_before_execution": 0.002729331, + "time_to_execute": 0.000215215, + "rows_read": 1, + "bytes_read": 1, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } + return Response(status_code=codes.OK, json=query_response) + + return do_query + + +@fixture +def get_engine_url_callback(server: str, db_name: str, status="Running") -> Callable: + return _get_engine_url_callback(server, db_name) + + +@fixture +def get_engine_url_not_running_callback(engine_name, db_name) -> Callable: + return _get_engine_url_callback(engine_name, db_name, "Stopped") + + +@fixture +def get_engine_url_invalid_db_callback(engine_name, db_name) -> Callable: + return _get_engine_url_callback(engine_name, "not_" + db_name) + + +def _get_default_db_engine_callback(server: str, status="Running") -> Callable: + def do_query(request: Request, **kwargs) -> Response: + set_parameters = request.url.params + assert len(set_parameters) == 1 and "output_format" in set_parameters + data = [[server, status]] + query_response = { + "meta": [{"name": "name", "type": "Text"} for _ in range(len(data[0]))], + "data": data, + "rows": len(data), + # Real example of statistics field value, not used by our code + "statistics": { + "elapsed": 0.002983335, + "time_before_execution": 0.002729331, + "time_to_execute": 0.000215215, + "rows_read": 1, + "bytes_read": 1, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } + return Response(status_code=codes.OK, json=query_response) + + return do_query + + +@fixture +def get_default_db_engine_callback(server: str) -> Callable: + return _get_default_db_engine_callback(server) + + +@fixture +def get_default_db_engine_not_running_callback(server: str) -> Callable: + return _get_default_db_engine_callback(server, "Failed") + + +@fixture +def system_engine_url() -> str: + return "https://bravo.a.eu-west-1.aws.mock.firebolt.io" + + +@fixture +def system_engine_query_url( + system_engine_url: str, db_name: str, account_id: str +) -> str: + return f"{system_engine_url}/dynamic/query?output_format=JSON_Compact&database={db_name}&account_id={account_id}" + + +@fixture +def system_engine_no_db_query_url(system_engine_url: str, account_id: str) -> str: + return f"{system_engine_url}/dynamic/query?output_format=JSON_Compact&account_id={account_id}" + + +@fixture +def get_system_engine_url(server: str, account_name: str) -> str: + return URL( + f"https://{server}" + f"{GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name)}" + ) + + +@fixture +def get_system_engine_callback(system_engine_url: str) -> Callable: + def inner( + request: Request = None, + **kwargs, + ) -> Response: + assert request, "empty request" + assert request.method == "GET", "invalid request method" + + return Response( + status_code=codes.OK, + json={"engineUrl": system_engine_url}, + ) + + return inner + + +@fixture +def mock_connection_flow( + httpx_mock: HTTPXMock, + auth_url: str, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + get_engine_url_callback: Callable, + account_id_url: str, + account_id_callback: Callable, +) -> Callable: + def inner() -> None: + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + + return inner + + +@fixture +def mock_query( + httpx_mock: HTTPXMock, + query_url: str, + query_callback: Callable, +) -> Callable: + def inner() -> None: + httpx_mock.add_callback(query_callback, url=query_url) + + return inner + + +@fixture +def mock_insert_query( + httpx_mock: HTTPXMock, + query_url: str, + insert_query_callback: Callable, +) -> Callable: + def inner() -> None: + httpx_mock.add_callback(insert_query_callback, url=query_url) + + return inner diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index 2043f8a4dcf..4cb7a7789a2 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -290,9 +290,9 @@ def account_engine_url(settings: Settings, account_id, mock_engine) -> str: @fixture -def mock_database(region_1, account_id) -> Database: +def mock_database(region_1: str, account_id: str) -> Database: return Database( - name="mock_db_name", + name="database", description="mock_db_description", compute_region_key=region_1.key, database_key=DatabaseKey( diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 18b93af5c44..9686ecbf7ff 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -288,19 +288,19 @@ def test_get_connection( database_url: str, bindings_callback: Callable, bindings_url: str, + mock_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") httpx_mock.add_callback(bindings_callback, url=bindings_url) httpx_mock.add_callback(database_callback, url=database_url) + mock_connection_flow() manager = ResourceManager(settings=settings) engine = manager.engines.create(name=engine_name) diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 42060b2785e..9ddfeacb468 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -5,7 +5,7 @@ from pytest import mark, raises from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager from firebolt.utils.exception import AccountNotFoundError @@ -35,33 +35,6 @@ def test_rm_credentials( rm = ResourceManager(settings) rm.client.get(url) - token_settings = Settings( - access_token=access_token, - server=settings.server, - default_region=settings.default_region, - ) - - rm = ResourceManager(token_settings) - rm.client.get(url) - - auth_username_password_settings = Settings( - auth=UsernamePassword(settings.user, settings.password), - server=settings.server, - default_region=settings.default_region, - ) - - rm = ResourceManager(auth_username_password_settings) - rm.client.get(url) - - auth_token_settings = Settings( - auth=Token(access_token), - server=settings.server, - default_region=settings.default_region, - ) - - rm = ResourceManager(auth_token_settings) - rm.client.get(url) - @mark.nofakefs def test_rm_token_cache( @@ -70,13 +43,14 @@ def test_rm_token_cache( check_credentials_callback: Callable, settings: Settings, auth_url: str, + account_name: str, account_id_url: Pattern, account_id_callback: Callable, provider_callback: Callable, provider_url: str, access_token: str, ) -> None: - """Credentials, that are passed to rm are processed properly.""" + """Credentials, that are passed to rm are cached properly.""" url = "https://url" httpx_mock.add_callback(check_credentials_callback, url=auth_url) @@ -86,31 +60,37 @@ def test_rm_token_cache( with Patcher(): local_settings = Settings( - user=settings.user, - password=settings.password, + account_name=account_name, + auth=ClientCredentials( + settings.auth.client_id, + settings.auth.client_secret, + use_token_cache=True, + ), server=settings.server, default_region=settings.default_region, - use_token_cache=True, ) rm = ResourceManager(local_settings) rm.client.get(url) - ts = TokenSecureStorage(settings.user, settings.password) + ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): local_settings = Settings( - user=settings.user, - password=settings.password, + account_name=account_name, + auth=ClientCredentials( + settings.auth.client_id, + settings.auth.client_secret, + use_token_cache=False, + ), server=settings.server, default_region=settings.default_region, - use_token_cache=False, ) rm = ResourceManager(local_settings) rm.client.get(url) - ts = TokenSecureStorage(settings.user, settings.password) + ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" diff --git a/tests/unit/util.py b/tests/unit/util.py index 2f5528a9e6a..98b72304fc1 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -3,6 +3,7 @@ from httpx import Request, Response from firebolt.client import AsyncClient, Client +from firebolt.client.auth import Auth from firebolt.model import FireboltBaseModel @@ -15,7 +16,10 @@ def execute_generator_requests( ) -> None: request = next(requests) - with Client(api_endpoint=api_endpoint) as client: + with Client( + account_name="account", auth=Auth(), api_endpoint=api_endpoint + ) as client: + client._auth = None while True: response = client.send(request) try: @@ -30,7 +34,10 @@ async def async_execute_generator_requests( ) -> None: request = await requests.__anext__() - async with AsyncClient(api_endpoint=api_endpoint) as client: + async with AsyncClient( + account_name="account", auth=Auth(), api_endpoint=api_endpoint + ) as client: + client._auth = None while True: response = await client.send(request) try: