diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index dc9db0b55d..26aa5dc02c 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -11,7 +11,10 @@ from firebolt.client.auth import Auth from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 from firebolt.common.base_connection import BaseConnection -from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.common.constants import ( + DEFAULT_TIMEOUT_SECONDS, + ENGINE_STATUS_RUNNING_LIST, +) from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -71,6 +74,7 @@ def __init__( cursor_type: Type[Cursor], system_engine_connection: Optional["Connection"], api_endpoint: str, + init_parameters: Optional[Dict[str, Any]] = None, ): super().__init__() self.api_endpoint = api_endpoint @@ -80,6 +84,7 @@ def __init__( self._cursors: List[Cursor] = [] self._system_engine_connection = system_engine_connection self._client = client + self.init_parameters = init_parameters def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -142,8 +147,8 @@ async def connect( user_agent_header = get_user_agent_header(user_drivers, user_clients) # Use v2 if auth is ClientCredentials # Use v1 if auth is ServiceAccount or UsernamePassword - version = auth.get_firebolt_version() - if version == 2: + auth_version = auth.get_firebolt_version() + if auth_version == 2: assert account_name is not None return await connect_v2( auth=auth, @@ -153,7 +158,7 @@ async def connect( engine_name=engine_name, api_endpoint=api_endpoint, ) - elif version == 1: + elif auth_version == 1: return await connect_v1( auth=auth, user_agent_header=user_agent_header, @@ -223,6 +228,26 @@ async def connect_v2( None, api_endpoint, ) + + account_version = await system_engine_connection._client._account_version + if account_version == 2: + cursor = system_engine_connection.cursor() + if database: + await cursor.execute(f"USE DATABASE {database}") + if engine_name: + await cursor.execute(f"USE ENGINE {engine_name}") + # Ensure cursors created from this conection are using the same starting + # database and engine + return Connection( + cursor.engine_url, + cursor.database, + client, + CursorV2, + system_engine_connection, + api_endpoint, + cursor.parameters, + ) + if not engine_name: return system_engine_connection @@ -237,7 +262,7 @@ async def connect_v2( attached_db, ) = await cursor._get_engine_url_status_db(engine_name) - if status != "Running": + if status not in ENGINE_STATUS_RUNNING_LIST: raise EngineNotRunningError(engine_name) if database is not None and database != attached_db: diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 04ff37078e..eabb01d55c 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -17,12 +17,9 @@ Union, ) -from httpx import URL -from httpx import AsyncClient as HttpxAsyncClient -from httpx import Response, codes +from httpx import URL, Headers, Response, codes -from firebolt.async_db.util import ENGINE_STATUS_RUNNING -from firebolt.client.client import AsyncClientV1, AsyncClientV2 +from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 from firebolt.common._types import ( ColType, Column, @@ -33,14 +30,20 @@ ) from firebolt.common.base_cursor import ( JSON_OUTPUT_FORMAT, + RESET_SESSION_HEADER, + UPDATE_ENDPOINT_HEADER, + UPDATE_PARAMETERS_HEADER, BaseCursor, CursorState, QueryStatus, Statistics, + _parse_update_endpoint, + _parse_update_parameters, _raise_if_internal_set_parameter, check_not_closed, check_query_executed, ) +from firebolt.common.constants import ENGINE_STATUS_RUNNING_LIST from firebolt.utils.exception import ( AsyncExecutionUnavailableError, EngineNotRunningError, @@ -76,23 +79,18 @@ class Cursor(BaseCursor, metaclass=ABCMeta): def __init__( self, *args: Any, - client: HttpxAsyncClient, + client: AsyncClient, connection: Connection, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self._client = client self.connection = connection + self.engine_url = connection.engine_url if connection.database: self.database = connection.database - - @property - def database(self) -> Optional[str]: - return self.parameters.get("database") - - @database.setter - def database(self, database: str) -> None: - self.parameters["database"] = database + if connection.init_parameters: + self._update_set_parameters(connection.init_parameters) @abstractmethod async def _api_request( @@ -117,9 +115,9 @@ async def _raise_if_error(self, resp: Response) -> None: if ( resp.status_code == codes.SERVICE_UNAVAILABLE or resp.status_code == codes.NOT_FOUND - ) and not await self.is_engine_running(self.connection.engine_url): + ) and not await self.is_engine_running(self.engine_url): raise EngineNotRunningError( - f"Firebolt engine {self.connection.engine_url} " + f"Firebolt engine {self.engine_url} " "needs to be running to run queries against it." ) _print_error_body(resp) @@ -143,6 +141,30 @@ async def _validate_set_parameter(self, parameter: SetParameter) -> None: # set parameter passed validation self._set_parameters[parameter.name] = parameter.value + async def _parse_response_headers(self, headers: Headers) -> None: + if headers.get(UPDATE_ENDPOINT_HEADER): + endpoint, params = _parse_update_endpoint( + headers.get(UPDATE_ENDPOINT_HEADER) + ) + if ( + params.get("account_id", await self._client.account_id) + != await self._client.account_id + ): + raise OperationalError( + "USE ENGINE command failed. Account parameter mismatch. " + "Contact support" + ) + self._update_set_parameters(params) + self.engine_url = endpoint + self._client.base_url = URL(endpoint) + + if headers.get(RESET_SESSION_HEADER): + self.flush_parameters() + + if headers.get(UPDATE_PARAMETERS_HEADER): + param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER)) + self._update_set_parameters(param_dict) + async def _do_execute( self, raw_query: str, @@ -209,8 +231,7 @@ async def _do_execute( query, {"output_format": JSON_OUTPUT_FORMAT} ) await self._raise_if_error(resp) - # get parameters from response - self._parse_response_headers(resp.headers) + await self._parse_response_headers(resp.headers) row_set = self._row_set_from_response(resp) self._append_row_set(row_set) @@ -452,7 +473,8 @@ async def _api_request( parameters = {**(self._set_parameters or {}), **parameters} if self.parameters: parameters = {**self.parameters, **parameters} - if self.connection._is_system: + # Engines v2 always require account_id + if self.connection._is_system or (await self._client._account_version) == 2: assert isinstance(self._client, AsyncClientV2) parameters["account_id"] = await self._client.account_id return await self._client.request( @@ -495,7 +517,6 @@ async def is_engine_running(self, engine_url: str) -> bool: # System engine is always running return True - engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") assert self.connection._system_engine_connection is not None # Type check system_cursor = self.connection._system_engine_connection.cursor() assert isinstance(system_cursor, CursorV2) # Type check, should always be true @@ -503,8 +524,8 @@ async def is_engine_running(self, engine_url: str) -> bool: _, status, _, - ) = await system_cursor._get_engine_url_status_db(engine_name) - return status == ENGINE_STATUS_RUNNING + ) = await system_cursor._get_engine_url_status_db(self.engine_name) + return status in ENGINE_STATUS_RUNNING_LIST async def _get_engine_url_status_db(self, engine_name: str) -> Tuple[str, str, str]: await self.execute( diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index e497a3b66e..1f051720a5 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -11,8 +11,6 @@ ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -ENGINE_STATUS_RUNNING = "Running" - async def _get_system_engine_url( auth: Auth, diff --git a/src/firebolt/client/constants.py b/src/firebolt/client/constants.py index 6bbc2c91e7..78d23c1bfa 100644 --- a/src/firebolt/client/constants.py +++ b/src/firebolt/client/constants.py @@ -5,7 +5,7 @@ DEFAULT_API_URL: str = "api.app.firebolt.io" PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version" -PROTOCOL_VERSION: str = "2.0" +PROTOCOL_VERSION: str = "2.1" _REQUEST_ERRORS: Tuple[Type, ...] = ( HTTPError, InvalidURL, diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index ea40b869dc..65d9d2f3de 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -7,7 +7,7 @@ from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -from httpx import Headers, Response +from httpx import URL, Response from firebolt.common._types import ( ColType, @@ -25,7 +25,7 @@ DataError, QueryNotRunError, ) -from firebolt.utils.util import Timer +from firebolt.utils.util import Timer, fix_url_schema logger = logging.getLogger(__name__) @@ -53,13 +53,32 @@ class QueryStatus(Enum): EXECUTION_ERROR = 8 -# known parameters that can be set on the server side -SERVER_SIDE_PARAMETERS = ["database"] - # Parameters that should be set using USE instead of SET USE_PARAMETER_LIST = ["database", "engine"] # parameters that can only be set by the backend DISALLOWED_PARAMETER_LIST = ["account_id", "output_format"] +# parameters that are set by the backend and should not be set by the user +IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST + +UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint" +UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters" +RESET_SESSION_HEADER = "Firebolt-Reset-Session" + + +def _parse_update_parameters(parameter_header: str) -> Dict[str, str]: + """Parse update parameters and set them as attributes.""" + # parse key1=value1,key2=value2 comma separated string into dict + param_dict = dict(item.split("=") for item in parameter_header.split(",")) + # strip whitespace from keys and values + param_dict = {key.strip(): value.strip() for key, value in param_dict.items()} + return param_dict + + +def _parse_update_endpoint( + new_engine_endpoint_header: str, +) -> Tuple[str, Dict[str, str]]: + endpoint = URL(fix_url_schema(new_engine_endpoint_header)) + return fix_url_schema(endpoint.host), dict(endpoint.params) def _raise_if_internal_set_parameter(parameter: SetParameter) -> None: @@ -150,6 +169,7 @@ class BaseCursor: "_next_set_idx", "_set_parameters", "_query_id", + "engine_url", ) default_arraysize = 1 @@ -168,14 +188,25 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: Optional[List[List[RawColType]]], ] ] = [] + # User-defined set parameters self._set_parameters: Dict[str, Any] = dict() + # Server-side parameters (user can't change them) self.parameters: Dict[str, str] = dict() + self.engine_url = "" self._rowcount = -1 self._idx = 0 self._next_set_idx = 0 self._query_id = "" self._reset() + @property + def database(self) -> Optional[str]: + return self.parameters.get("database") + + @database.setter + def database(self, database: str) -> None: + self.parameters["database"] = database + @property # type: ignore @check_not_closed def description(self) -> Optional[List[Column]]: @@ -273,25 +304,38 @@ def _reset(self) -> None: self._next_set_idx = 0 self._query_id = "" - def _parse_response_headers(self, headers: Headers) -> None: - """Parse response and update relevant cursor fields.""" - update_parameters = headers.get("Firebolt-Update-Parameters") - # parse update parameters dict and set keys as attributes - if update_parameters: - # parse key1=value1,key2=value2 comma separated string into dict - param_dict = dict(item.split("=") for item in update_parameters.split(",")) - # strip whitespace from keys and values - param_dict = { - key.strip(): value.strip() for key, value in param_dict.items() - } - for key, value in param_dict.items(): - if key in SERVER_SIDE_PARAMETERS: - self.parameters[key] = value - else: - logger.debug( - f"Unknown parameter {key} returned by the server. " - "It will be ignored." - ) + def _update_set_parameters(self, parameters: Dict[str, Any]) -> None: + # Split parameters into immutable and user parameters + immutable_parameters = { + key: value + for key, value in parameters.items() + if key in IMMUTABLE_PARAMETER_LIST + } + user_parameters = { + key: value + for key, value in parameters.items() + if key not in IMMUTABLE_PARAMETER_LIST + } + + self.parameters.update(immutable_parameters) + + self._set_parameters.update(user_parameters) + + def _update_server_parameters(self, parameters: Dict[str, Any]) -> None: + for key, value in parameters.items(): + self.parameters[key] = value + + @property + def engine_name(self) -> str: + """ + Get the name of the engine that we're using. + + Args: + engine_url (str): URL of the engine + """ + if self.parameters.get("engine"): + return self.parameters["engine"] + return URL(self.engine_url).host.split(".")[0].replace("-", "_") def _row_set_from_response( self, response: Response diff --git a/src/firebolt/common/constants.py b/src/firebolt/common/constants.py index cb48d28feb..7c5412ce9c 100644 --- a/src/firebolt/common/constants.py +++ b/src/firebolt/common/constants.py @@ -1,3 +1,6 @@ KEEPALIVE_FLAG: int = 1 KEEPIDLE_RATE: int = 60 # seconds DEFAULT_TIMEOUT_SECONDS: int = 60 + +# Running statuses in infromation schema +ENGINE_STATUS_RUNNING_LIST = ["Running", "ENGINE_STATE_RUNNING"] diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 48e24a84fc..b023e81540 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -10,7 +10,10 @@ from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2 from firebolt.client.auth import Auth from firebolt.common.base_connection import BaseConnection -from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.common.constants import ( + DEFAULT_TIMEOUT_SECONDS, + ENGINE_STATUS_RUNNING_LIST, +) from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.db.util import _get_system_engine_url from firebolt.utils.exception import ( @@ -45,10 +48,10 @@ def connect( user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) user_agent_header = get_user_agent_header(user_drivers, user_clients) - version = auth.get_firebolt_version() + auth_version = auth.get_firebolt_version() # Use v2 if auth is ClientCredentials # Use v1 if auth is ServiceAccount or UsernamePassword - if version == 2: + if auth_version == 2: assert account_name is not None return connect_v2( auth=auth, @@ -58,7 +61,7 @@ def connect( engine_name=engine_name, api_endpoint=api_endpoint, ) - elif version == 1: + elif auth_version == 1: return connect_v1( auth=auth, user_agent_header=user_agent_header, @@ -128,9 +131,27 @@ def connect_v2( None, api_endpoint, ) + + if system_engine_connection._client._account_version == 2: + cursor = system_engine_connection.cursor() + if database: + cursor.execute(f"USE DATABASE {database}") + if engine_name: + cursor.execute(f"USE ENGINE {engine_name}") + # Ensure cursors created from this conection are using the same starting + # database and engine + return Connection( + cursor.engine_url, + cursor.database, + client, + CursorV2, + system_engine_connection, + api_endpoint, + cursor.parameters, + ) + if not engine_name: return system_engine_connection - else: try: cursor = system_engine_connection.cursor() @@ -141,7 +162,7 @@ def connect_v2( attached_db, ) = cursor._get_engine_url_status_db(system_engine_connection, engine_name) - if status != "Running": + if status not in ENGINE_STATUS_RUNNING_LIST: raise EngineNotRunningError(engine_name) if database is not None and database != attached_db: @@ -215,6 +236,7 @@ def __init__( cursor_type: Type[Cursor], system_engine_connection: Optional["Connection"], api_endpoint: str = DEFAULT_API_URL, + init_parameters: Optional[Dict[str, Any]] = None, ): self.api_endpoint = api_endpoint self.engine_url = engine_url @@ -222,8 +244,8 @@ def __init__( self.cursor_type = cursor_type self._cursors: List[Cursor] = [] self._system_engine_connection = system_engine_connection - # Override tcp keepalive settings for connection self._client = client + self.init_parameters = init_parameters or {} super().__init__() def cursor(self, **kwargs: Any) -> Cursor: diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 6117455b5a..71e7fc7039 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -15,7 +15,7 @@ Union, ) -from httpx import URL, Response, codes +from httpx import URL, Headers, Response, codes from firebolt.client import Client, ClientV1, ClientV2 from firebolt.common._types import ( @@ -28,15 +28,20 @@ ) from firebolt.common.base_cursor import ( JSON_OUTPUT_FORMAT, + RESET_SESSION_HEADER, + UPDATE_ENDPOINT_HEADER, + UPDATE_PARAMETERS_HEADER, BaseCursor, CursorState, QueryStatus, Statistics, + _parse_update_endpoint, + _parse_update_parameters, _raise_if_internal_set_parameter, check_not_closed, check_query_executed, ) -from firebolt.db.util import ENGINE_STATUS_RUNNING +from firebolt.common.constants import ENGINE_STATUS_RUNNING_LIST from firebolt.utils.exception import ( AsyncExecutionUnavailableError, EngineNotRunningError, @@ -78,16 +83,11 @@ def __init__( super().__init__(*args, **kwargs) self._client = client self.connection = connection + self.engine_url = connection.engine_url if connection.database: self.database = connection.database - - @property - def database(self) -> Optional[str]: - return self.parameters.get("database") - - @database.setter - def database(self, database: str) -> None: - self.parameters["database"] = database + if connection.init_parameters: + self._update_set_parameters(connection.init_parameters) def _raise_if_error(self, resp: Response) -> None: """Raise a proper error if any""" @@ -104,9 +104,9 @@ def _raise_if_error(self, resp: Response) -> None: if ( resp.status_code == codes.SERVICE_UNAVAILABLE or resp.status_code == codes.NOT_FOUND - ) and not self.is_engine_running(self.connection.engine_url): + ) and not self.is_engine_running(self.engine_url): raise EngineNotRunningError( - f"Firebolt engine {self.connection.engine_url} " + f"Firebolt engine {self.engine_name} " "needs to be running to run queries against it." # pragma: no mutate # noqa: E501 ) _print_error_body(resp) @@ -140,6 +140,30 @@ def _validate_set_parameter(self, parameter: SetParameter) -> None: # set parameter passed validation self._set_parameters[parameter.name] = parameter.value + def _parse_response_headers(self, headers: Headers) -> None: + if headers.get(UPDATE_ENDPOINT_HEADER): + endpoint, params = _parse_update_endpoint( + headers.get(UPDATE_ENDPOINT_HEADER) + ) + if ( + params.get("account_id", self._client.account_id) + != self._client.account_id + ): + raise OperationalError( + "USE ENGINE command failed. Account parameter mismatch. " + "Contact support" + ) + self._update_set_parameters(params) + self.engine_url = endpoint + self._client.base_url = URL(endpoint) + + if headers.get(RESET_SESSION_HEADER): + self.flush_parameters() + + if headers.get(UPDATE_PARAMETERS_HEADER): + param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER)) + self._update_set_parameters(param_dict) + def _do_execute( self, raw_query: str, @@ -200,7 +224,6 @@ def _do_execute( query, {"output_format": JSON_OUTPUT_FORMAT} ) self._raise_if_error(resp) - # get parameters from response self._parse_response_headers(resp.headers) row_set = self._row_set_from_response(resp) @@ -395,7 +418,8 @@ def _api_request( parameters = {**(self._set_parameters or {}), **parameters} if self.parameters: parameters = {**self.parameters, **parameters} - if self.connection._is_system: + # Engines v2 always require account_id + if self.connection._is_system or self._client._account_version == 2: assert isinstance(self._client, ClientV2) # Type check parameters["account_id"] = self._client.account_id return self._client.request( @@ -438,12 +462,11 @@ def is_engine_running(self, engine_url: str) -> bool: # System engine is always running return True - engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") assert self.connection._system_engine_connection is not None # Type check _, status, _ = self._get_engine_url_status_db( - self.connection._system_engine_connection, engine_name + self.connection._system_engine_connection, self.engine_name ) - return status == ENGINE_STATUS_RUNNING + return status in ENGINE_STATUS_RUNNING_LIST def _get_engine_url_status_db( self, system_engine_connection: Connection, engine_name: str diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 93e2124eb9..4a399cbe27 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -11,8 +11,6 @@ ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -ENGINE_STATUS_RUNNING = "Running" - def _get_system_engine_url( auth: Auth, diff --git a/tests/integration/dbapi/async/V2/conftest.py b/tests/integration/dbapi/async/V2/conftest.py index ae0b1e296f..d4117d37d3 100644 --- a/tests/integration/dbapi/async/V2/conftest.py +++ b/tests/integration/dbapi/async/V2/conftest.py @@ -61,13 +61,11 @@ async def connection_system_engine( @fixture async def connection_system_engine_v2( - database_name: str, auth: Auth, account_name_v2: str, api_endpoint: str, ) -> Connection: async with await connect( - database=database_name, auth=auth, account_name=account_name_v2, api_endpoint=api_endpoint, @@ -75,6 +73,19 @@ async def connection_system_engine_v2( yield connection +@fixture +async def engine_v2( + connection_system_engine_v2: Connection, + engine_name: str, +) -> str: + cursor = connection_system_engine_v2.cursor() + await cursor.execute(f"CREATE ENGINE IF NOT EXISTS {engine_name}") + await cursor.execute(f"START ENGINE {engine_name}") + yield engine_name + await cursor.execute(f"STOP ENGINE {engine_name}") + await cursor.execute(f"DROP ENGINE IF EXISTS {engine_name}") + + @fixture async def connection_system_engine_no_db( auth: Auth, diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index 0e7cda8d97..0768600ffc 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -2,13 +2,15 @@ from datetime import date, datetime from decimal import Decimal from os import environ -from random import choice -from typing import Callable, List +from random import choice, randint +from typing import Callable, Generator, List from pytest import fixture, mark, raises from firebolt.async_db import Binary, Connection, Cursor, OperationalError +from firebolt.async_db.connection import connect from firebolt.async_db.cursor import QueryStatus +from firebolt.client.auth.base import Auth from firebolt.common._types import ColType, Column from tests.integration.conftest import API_ENDPOINT_ENV from tests.integration.dbapi.utils import assert_deep_eq @@ -272,7 +274,7 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: ) # \0 is converted to 0 - params[2] = "text0" + params[2] = "text\\0" assert ( await c.execute("SELECT * FROM test_tb_async_parameterized") == 1 @@ -280,7 +282,7 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: assert_deep_eq( await c.fetchall(), - [params + ["?"]], + [params + ["\\?"]], "Invalid data in table after parameterized insert", ) @@ -409,6 +411,9 @@ async def test_bytea_roundtrip( ) -> None: """Inserted and than selected bytea value doesn't get corrupted.""" with connection.cursor() as c: + # Set standard_conforming_strings to 0 to allow bytea escape sequences + # FIR-30650 + await c.execute("SET standard_conforming_strings=0") await c.execute("DROP TABLE IF EXISTS test_bytea_roundtrip_2") await c.execute( "CREATE FACT TABLE test_bytea_roundtrip_2(id int, b bytea) primary index id" @@ -423,35 +428,34 @@ async def test_bytea_roundtrip( bytes_data = (await c.fetchone())[0] + await c.execute("SET standard_conforming_strings=1") assert ( bytes_data.decode("utf-8") == data ), "Invalid bytea data returned after roundtrip" @fixture -async def setup_db(connection_system_engine_no_db: Connection, use_db_name: str): +async def setup_v2_db(connection_system_engine_v2: Connection, use_db_name: str): use_db_name = use_db_name + "_async" - with connection_system_engine_no_db.cursor() as cursor: + with connection_system_engine_v2.cursor() as cursor: # randomize the db name to avoid conflicts suffix = "".join(choice("0123456789") for _ in range(2)) await cursor.execute(f"CREATE DATABASE {use_db_name}{suffix}") - yield + yield f"{use_db_name}{suffix}" await cursor.execute(f"DROP DATABASE {use_db_name}{suffix}") @mark.xfail("dev" not in environ[API_ENDPOINT_ENV], reason="Only works on dev") async def test_use_database( - setup_db, + setup_v2_db, connection_system_engine_no_db: Connection, - use_db_name: str, database_name: str, ) -> None: - test_db_name = use_db_name + "_async" test_table_name = "verify_use_db_async" """Use database works as expected.""" with connection_system_engine_no_db.cursor() as c: - await c.execute(f"USE DATABASE {test_db_name}") - assert c.database == test_db_name + await c.execute(f"USE DATABASE {setup_v2_db}") + assert c.database == setup_v2_db await c.execute(f"CREATE TABLE {test_table_name} (id int)") await c.execute( "SELECT table_name FROM information_schema.tables " @@ -466,3 +470,48 @@ async def test_use_database( f"WHERE table_name = '{test_table_name}'" ) assert (await c.fetchone()) is None, "Database was not changed" + + +async def test_account_v2_connection_with_db( + setup_v2_db: Generator, + auth: Auth, + account_name_v2: str, + api_endpoint: str, +) -> None: + async with await connect( + database=setup_v2_db, + auth=auth, + account_name=account_name_v2, + api_endpoint=api_endpoint, + ) as connection: + # This fails if we're not running with a db context + await connection.cursor().execute( + "SELECT * FROM information_schema.tables LIMIT 1" + ) + + +async def test_account_v2_connection_with_db_and_engine( + setup_v2_db: Generator, + connection_system_engine_v2: Connection, + auth: Auth, + account_name_v2: str, + api_endpoint: str, + engine_v2: str, +) -> None: + system_cursor = connection_system_engine_v2.cursor() + # We can only connect to a running engine so start it first + # via the system connection to keep test isolated + await system_cursor.execute(f"START ENGINE {engine_v2}") + async with await connect( + database=setup_v2_db, + engine_name=engine_v2, + auth=auth, + account_name=account_name_v2, + api_endpoint=api_endpoint, + ) as connection: + # generate a random string to avoid name conflicts + rnd_suffix = str(randint(0, 1000)) + cursor = connection.cursor() + await cursor.execute(f"CREATE TABLE test_table_{rnd_suffix} (id int)") + # This fails if we're not running on a user engine + await cursor.execute(f"INSERT INTO test_table_{rnd_suffix} VALUES (1)") diff --git a/tests/integration/dbapi/async/V2/test_system_engine_async.py b/tests/integration/dbapi/async/V2/test_system_engine_async.py index 67e6138b77..76eb8212c6 100644 --- a/tests/integration/dbapi/async/V2/test_system_engine_async.py +++ b/tests/integration/dbapi/async/V2/test_system_engine_async.py @@ -80,3 +80,21 @@ async def test_system_engine_v2_account(connection_system_engine_v2: Connection) assert ( await connection_system_engine_v2._client._account_version ) == 2, "Invalid account version" + + +async def test_system_engine_use_engine( + connection_system_engine_v2: Connection, database_name: str, engine_name: str +): + with connection_system_engine_v2.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE IF NOT EXISTS {database_name}") + await cursor.execute(f"USE DATABASE {database_name}") + await cursor.execute(f"CREATE ENGINE IF NOT EXISTS {engine_name}") + await cursor.execute(f"USE ENGINE {engine_name}") + await cursor.execute("CREATE TABLE IF NOT EXISTS test_table (id int)") + # This query fails if we're not on a user engine + await cursor.execute("INSERT INTO test_table VALUES (1)") + await cursor.execute("USE ENGINE system") + # Werify we've switched to system by making previous query fail + with raises(OperationalError): + await cursor.execute("INSERT INTO test_table VALUES (1)") + await cursor.execute("DROP TABLE IF EXISTS test_table") diff --git a/tests/integration/dbapi/sync/V2/conftest.py b/tests/integration/dbapi/sync/V2/conftest.py index 351602a413..159ba7271b 100644 --- a/tests/integration/dbapi/sync/V2/conftest.py +++ b/tests/integration/dbapi/sync/V2/conftest.py @@ -59,15 +59,13 @@ def connection_system_engine( yield connection -@fixture +@fixture(scope="session") def connection_system_engine_v2( - database_name: str, auth: Auth, account_name_v2: str, api_endpoint: str, ) -> Connection: with connect( - database=database_name, auth=auth, account_name=account_name_v2, api_endpoint=api_endpoint, @@ -75,6 +73,19 @@ def connection_system_engine_v2( yield connection +@fixture(scope="session") +def engine_v2( + connection_system_engine_v2: Connection, + engine_name: str, +) -> str: + cursor = connection_system_engine_v2.cursor() + cursor.execute(f"CREATE ENGINE IF NOT EXISTS {engine_name}") + cursor.execute(f"START ENGINE {engine_name}") + yield engine_name + cursor.execute(f"STOP ENGINE {engine_name}") + cursor.execute(f"DROP ENGINE IF EXISTS {engine_name}") + + @fixture def connection_system_engine_no_db( auth: Auth, diff --git a/tests/integration/dbapi/sync/V2/test_queries.py b/tests/integration/dbapi/sync/V2/test_queries.py index 416c71369b..37de4bfa13 100644 --- a/tests/integration/dbapi/sync/V2/test_queries.py +++ b/tests/integration/dbapi/sync/V2/test_queries.py @@ -2,8 +2,9 @@ from datetime import date, datetime from decimal import Decimal from os import environ +from random import choice, randint from threading import Thread -from typing import Any, Callable, List +from typing import Any, Callable, Generator, List from pytest import fixture, mark, raises @@ -274,7 +275,7 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: ) # \0 is converted to 0 - params[2] = "text0" + params[2] = "text\\0" assert ( c.execute("SELECT * FROM test_tb_parameterized") == 1 @@ -282,7 +283,7 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: assert_deep_eq( c.fetchall(), - [params + ["?"]], + [params + ["\\?"]], "Invalid data in table after parameterized insert", ) @@ -493,6 +494,9 @@ def test_bytea_roundtrip( ) -> None: """Inserted and than selected bytea value doesn't get corrupted.""" with connection.cursor() as c: + # Set standard_conforming_strings to 0 to allow bytea escape sequences + # FIR-30650 + c.execute("SET standard_conforming_strings=0") c.execute("DROP TABLE IF EXISTS test_bytea_roundtrip") c.execute( "CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id" @@ -504,34 +508,34 @@ def test_bytea_roundtrip( c.execute("SELECT b FROM test_bytea_roundtrip") bytes_data = (c.fetchone())[0] - + c.execute("SET standard_conforming_strings=1") assert ( bytes_data.decode("utf-8") == data ), "Invalid bytea data returned after roundtrip" -@fixture -def setup_db(connection_system_engine, use_db_name): +@fixture(scope="module") +def setup_v2_db(connection_system_engine_v2, use_db_name): use_db_name = f"{use_db_name}_sync" - with connection_system_engine.cursor() as cursor: - cursor.execute(f"CREATE DATABASE {use_db_name}") - yield - cursor.execute(f"DROP DATABASE {use_db_name}") + with connection_system_engine_v2.cursor() as cursor: + # randomize the db name to avoid conflicts + suffix = "".join(choice("0123456789") for _ in range(2)) + cursor.execute(f"CREATE DATABASE {use_db_name}{suffix}") + yield f"{use_db_name}{suffix}" + cursor.execute(f"DROP DATABASE {use_db_name}{suffix}") @mark.xfail("dev" not in environ[API_ENDPOINT_ENV], reason="Only works on dev") def test_use_database( - setup_db, + setup_v2_db, connection_system_engine: Connection, - use_db_name: str, database_name: str, ) -> None: - test_db_name = f"{use_db_name}_sync" test_table_name = "verify_use_db" """Use database works as expected.""" with connection_system_engine.cursor() as c: - c.execute(f"USE DATABASE {test_db_name}") - assert c.database == test_db_name + c.execute(f"USE DATABASE {setup_v2_db}") + assert c.database == setup_v2_db c.execute(f"CREATE TABLE {test_table_name} (id int)") c.execute( "SELECT table_name FROM information_schema.tables " @@ -546,3 +550,46 @@ def test_use_database( f"WHERE table_name = '{test_table_name}'" ) assert c.fetchone() is None, "Database was not changed" + + +def test_account_v2_connection_with_db( + setup_v2_db: Generator, + auth: Auth, + account_name_v2: str, + api_endpoint: str, +) -> None: + with connect( + database=setup_v2_db, + auth=auth, + account_name=account_name_v2, + api_endpoint=api_endpoint, + ) as connection: + # This fails if we're not running with a db context + connection.cursor().execute("SELECT * FROM information_schema.tables LIMIT 1") + + +def test_account_v2_connection_with_db_and_engine( + setup_v2_db: Generator, + connection_system_engine_v2: Connection, + auth: Auth, + account_name_v2: str, + api_endpoint: str, + engine_v2: str, +) -> None: + system_cursor = connection_system_engine_v2.cursor() + # We can only connect to a running engine so start it first + # via the system connection to keep test isolated + system_cursor.execute(f"START ENGINE {engine_v2}") + with connect( + database=setup_v2_db, + engine_name=engine_v2, + auth=auth, + account_name=account_name_v2, + api_endpoint=api_endpoint, + ) as connection: + # generate a random string to avoid name conflicts + rnd_suffix = str(randint(0, 1000)) + cursor = connection.cursor() + cursor.execute(f"CREATE TABLE test_table_{rnd_suffix} (id int)") + # This fails if we're not running on a user engine + cursor.execute(f"INSERT INTO test_table_{rnd_suffix} VALUES (1)") diff --git a/tests/integration/dbapi/sync/V2/test_system_engine.py b/tests/integration/dbapi/sync/V2/test_system_engine.py index 9654dd2404..8dc3c8df28 100644 --- a/tests/integration/dbapi/sync/V2/test_system_engine.py +++ b/tests/integration/dbapi/sync/V2/test_system_engine.py @@ -80,3 +80,21 @@ def test_system_engine_v2_account(connection_system_engine_v2: Connection): assert ( connection_system_engine_v2._client._account_version == 2 ), "Invalid account version" + + +def test_system_engine_use_engine( + connection_system_engine_v2: Connection, database_name: str, engine_name: str +): + with connection_system_engine_v2.cursor() as cursor: + cursor.execute(f"CREATE DATABASE IF NOT EXISTS {database_name}") + cursor.execute(f"USE DATABASE {database_name}") + cursor.execute(f"CREATE ENGINE IF NOT EXISTS {engine_name}") + cursor.execute(f"USE ENGINE {engine_name}") + cursor.execute("CREATE TABLE IF NOT EXISTS test_table (id int)") + # This query fails if we're not on a user engine + cursor.execute("INSERT INTO test_table VALUES (1)") + cursor.execute("USE ENGINE system") + # Werify we've switched to system by making previous query fail + with raises(OperationalError): + cursor.execute("INSERT INTO test_table VALUES (1)") + cursor.execute("DROP TABLE IF EXISTS test_table") diff --git a/tests/unit/async_db/V2/conftest.py b/tests/unit/async_db/V2/conftest.py index d928c0e114..eff5d6affd 100644 --- a/tests/unit/async_db/V2/conftest.py +++ b/tests/unit/async_db/V2/conftest.py @@ -24,6 +24,8 @@ async def connection( api_endpoint=server, ) ) as connection: + # cache account_id for tests + await connection._client.account_id yield connection diff --git a/tests/unit/async_db/V2/test_cursor.py b/tests/unit/async_db/V2/test_cursor.py index f0062c65ba..85fdffee85 100644 --- a/tests/unit/async_db/V2/test_cursor.py +++ b/tests/unit/async_db/V2/test_cursor.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, List from unittest.mock import patch -from httpx import HTTPStatusError, StreamError, codes +from httpx import URL, HTTPStatusError, Request, StreamError, codes from pytest import LogCaptureFixture, mark, raises from pytest_httpx import HTTPXMock @@ -913,3 +913,120 @@ async def test_disallowed_set_parameter(cursor: Cursor, parameter: str) -> None: e.value ), "invalid error" assert cursor._set_parameters == {}, "set parameters should not be updated" + + +async def test_cursor_use_engine_no_parameters( + httpx_mock: HTTPXMock, + query_url: URL, + cursor: Cursor, + query_statistics: Dict[str, Any], +): + query_updated_url = "my_dummy_url" + + def query_callback_with_headers(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + headers = {"Firebolt-Update-Endpoint": f"https://{query_updated_url}"} + return Response(status_code=codes.OK, json=query_response, headers=headers) + + httpx_mock.add_callback(query_callback_with_headers, url=query_url) + assert cursor.engine_url == "https://" + query_url.host + await cursor.execute("USE ENGINE = 'my_dummy_engine'") + assert cursor.engine_url == f"https://{query_updated_url}" + + httpx_mock.reset(True) + # Check updated engine is used in the next query + new_url = query_url.copy_with(host=query_updated_url) + httpx_mock.add_callback(query_callback_with_headers, url=new_url) + await cursor.execute("select 1") + assert cursor.engine_url == f"https://{query_updated_url}" + + +async def test_cursor_use_engine_with_parameters( + httpx_mock: HTTPXMock, + query_url: URL, + cursor: Cursor, + query_statistics: Dict[str, Any], +): + query_updated_url = "my_dummy_url" + param_string_dummy = "param1=1¶m2=2&engine=my_dummy_engine" + + header = { + "Firebolt-Update-Endpoint": f"https://{query_updated_url}/?{param_string_dummy}" + } + + def query_callback_with_headers(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + headers = header + return Response(status_code=codes.OK, json=query_response, headers=headers) + + httpx_mock.add_callback(query_callback_with_headers, url=query_url) + assert cursor.engine_url == "https://" + query_url.host + await cursor.execute("USE ENGINE = 'my_dummy_engine'") + assert cursor.engine_url == f"https://{query_updated_url}" + assert cursor._set_parameters == {"param1": "1", "param2": "2"} + assert list(cursor.parameters.keys()) == ["database", "engine"] + assert cursor.engine_name == "my_dummy_engine" + + httpx_mock.reset(True) + # Check new parameters are used in the URL + new_url = query_url.copy_with(host=query_updated_url).copy_merge_params( + {"param1": "1", "param2": "2", "engine": "my_dummy_engine"} + ) + httpx_mock.add_callback(query_callback_with_headers, url=new_url) + await cursor.execute("select 1") + assert cursor.engine_url == f"https://{query_updated_url}" + + +async def test_cursor_reset_session( + httpx_mock: HTTPXMock, + select_one_query_callback: Callable, + set_query_url: str, + cursor: Cursor, + query_statistics: Dict[str, Any], +): + def query_callback_with_headers(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + headers = {"Firebolt-Reset-Session": "any_value_here"} + return Response(status_code=codes.OK, json=query_response, headers=headers) + + httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") + + assert len(cursor._set_parameters) == 0 + + await cursor.execute("set a = b") + assert ( + len(cursor._set_parameters) == 1 + and "a" in cursor._set_parameters + and cursor._set_parameters["a"] == "b" + ) + + httpx_mock.reset(True) + httpx_mock.add_callback( + query_callback_with_headers, + url=f"{set_query_url}&a=b&output_format=JSON_Compact", + ) + await cursor.execute("SELECT 1") + assert len(cursor._set_parameters) == 0 + assert bool(cursor.engine_url) is True, "engine url is not set" + assert bool(cursor.database) is True, "database is not set" diff --git a/tests/unit/common/test_base_cursor.py b/tests/unit/common/test_base_cursor.py index 44ec305d1b..d78f91417a 100644 --- a/tests/unit/common/test_base_cursor.py +++ b/tests/unit/common/test_base_cursor.py @@ -1,4 +1,4 @@ -import logging +from typing import Dict from unittest.mock import MagicMock from pytest import fixture, mark @@ -13,27 +13,46 @@ def cursor(): return cursor +@fixture +def initial_parameters() -> Dict[str, str]: + return {"key1": "value1", "key2": "value2"} + + @mark.parametrize( - "headers, expected_parameters", + "set_params, expected", [ ( - {"Firebolt-Update-Parameters": "database=value1, key2=value2"}, - {"database": "value1"}, - ), - ( - {"Firebolt-Update-Parameters": "database = value1 ,key3= value3 "}, - {"database": "value1"}, + {"key2": "new_value2", "key3": "value3"}, + {"key1": "value1", "key2": "new_value2", "key3": "value3"}, ), + ({}, {"key1": "value1", "key2": "value2"}), ], ) -def test_parse_response_headers(headers, expected_parameters, cursor, caplog): - # Capture the debug messages - with caplog.at_level(logging.DEBUG, logger="firebolt.common.base_cursor"): - # Call the function with the mock headers - cursor._parse_response_headers(headers) - +def test_update_set_parameters( + set_params: Dict[str, str], + expected: Dict[str, str], + initial_parameters: Dict[str, str], + cursor: BaseCursor, +): + cursor._set_parameters = initial_parameters + cursor._update_set_parameters(set_params) # Assert that the parameters have been correctly updated - assert cursor.parameters == expected_parameters + assert cursor._set_parameters == expected + + +def test_flush_parameters(initial_parameters: Dict[str, str], cursor: BaseCursor): + cursor._set_parameters = initial_parameters + cursor.flush_parameters() + assert cursor._set_parameters == {} + + +def test_update_server_parameters_known_params( + initial_parameters: Dict[str, str], cursor: BaseCursor +): + cursor.parameters = initial_parameters + cursor._update_set_parameters({"database": "new_database"}) - # Check that the debug message has been logged - assert "Unknown parameter" in caplog.text + # Merge the dictionaries using the update() method + updated_parameters = initial_parameters.copy() + updated_parameters.update({"database": "new_database"}) + assert cursor.parameters == updated_parameters diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0e9f9a088f..a1a845d54c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -58,6 +58,11 @@ def client_secret() -> str: return "client_secret" +@fixture +def server_name() -> str: + return "api_dev" + + @fixture def server() -> str: return "api-dev.mock.firebolt.io" @@ -236,6 +241,11 @@ def engine_name() -> str: return "mock_engine_name" +@fixture +def engine_url(engine_name: str) -> str: + return f"{engine_name}.mock.firebolt.io" + + @fixture def get_engine_name_by_id_url(server: str, account_id: str, engine_id: str) -> str: return f"https://{server}" + ACCOUNT_ENGINE_URL.format( diff --git a/tests/unit/db/V2/conftest.py b/tests/unit/db/V2/conftest.py index 90bd7dc308..1bd687d8fd 100644 --- a/tests/unit/db/V2/conftest.py +++ b/tests/unit/db/V2/conftest.py @@ -23,6 +23,8 @@ def connection( account_name=account_name, api_endpoint=server, ) as connection: + # cache account_id for tests + connection._client.account_id yield connection diff --git a/tests/unit/db/V2/test_cursor.py b/tests/unit/db/V2/test_cursor.py index df2b091369..c0d3f1b8d0 100644 --- a/tests/unit/db/V2/test_cursor.py +++ b/tests/unit/db/V2/test_cursor.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, List from unittest.mock import patch -from httpx import HTTPStatusError, StreamError, codes +from httpx import URL, HTTPStatusError, Request, StreamError, codes from pytest import LogCaptureFixture, mark, raises from pytest_httpx import HTTPXMock @@ -189,7 +189,7 @@ def test_cursor_execute( def test_cursor_execute_error( httpx_mock: HTTPXMock, get_engines_url: str, - server: str, + server_name: str, db_name: str, query_url: str, query_statistics: Dict[str, Any], @@ -323,7 +323,7 @@ def http_error(*args, **kwargs): with raises(EngineNotRunningError) as excinfo: query() assert cursor._state == CursorState.ERROR - assert server in str(excinfo) + assert server_name in str(excinfo) # Engine does not exist httpx_mock.add_callback( @@ -818,3 +818,120 @@ def test_disallowed_set_parameter(cursor: Cursor, parameter: str) -> None: e.value ), "invalid error" assert cursor._set_parameters == {}, "set parameters should not be updated" + + +def test_cursor_use_engine_no_parameters( + httpx_mock: HTTPXMock, + query_url: URL, + cursor: Cursor, + query_statistics: Dict[str, Any], +): + query_updated_url = "my_dummy_url" + + def query_callback_with_headers(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + headers = {"Firebolt-Update-Endpoint": f"https://{query_updated_url}"} + return Response(status_code=codes.OK, json=query_response, headers=headers) + + httpx_mock.add_callback(query_callback_with_headers, url=query_url) + assert cursor.engine_url == "https://" + query_url.host + cursor.execute("USE ENGINE = 'my_dummy_engine'") + assert cursor.engine_url == f"https://{query_updated_url}" + + httpx_mock.reset(True) + # Check updated engine is used in the next query + new_url = query_url.copy_with(host=query_updated_url) + httpx_mock.add_callback(query_callback_with_headers, url=new_url) + cursor.execute("select 1") + assert cursor.engine_url == f"https://{query_updated_url}" + + +def test_cursor_use_engine_with_parameters( + httpx_mock: HTTPXMock, + query_url: URL, + cursor: Cursor, + query_statistics: Dict[str, Any], +): + query_updated_url = "my_dummy_url" + param_string_dummy = "param1=1¶m2=2&engine=my_dummy_engine" + + header = { + "Firebolt-Update-Endpoint": f"https://{query_updated_url}/?{param_string_dummy}" + } + + def query_callback_with_headers(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + headers = header + return Response(status_code=codes.OK, json=query_response, headers=headers) + + httpx_mock.add_callback(query_callback_with_headers, url=query_url) + assert cursor.engine_url == "https://" + query_url.host + cursor.execute("USE ENGINE = 'my_dummy_engine'") + assert cursor.engine_url == f"https://{query_updated_url}" + assert cursor._set_parameters == {"param1": "1", "param2": "2"} + assert list(cursor.parameters.keys()) == ["database", "engine"] + assert cursor.engine_name == "my_dummy_engine" + + httpx_mock.reset(True) + # Check new parameters are used in the URL + new_url = query_url.copy_with(host=query_updated_url).copy_merge_params( + {"param1": "1", "param2": "2", "engine": "my_dummy_engine"} + ) + httpx_mock.add_callback(query_callback_with_headers, url=new_url) + cursor.execute("select 1") + assert cursor.engine_url == f"https://{query_updated_url}" + + +def test_cursor_reset_session( + httpx_mock: HTTPXMock, + select_one_query_callback: Callable, + set_query_url: str, + cursor: Cursor, + query_statistics: Dict[str, Any], +): + def query_callback_with_headers(request: Request, **kwargs) -> Response: + assert request.read() != b"" + assert request.method == "POST" + query_response = { + "meta": [{"name": "one", "type": "int"}], + "data": [1], + "rows": 1, + "statistics": query_statistics, + } + headers = {"Firebolt-Reset-Session": "any_value_here"} + return Response(status_code=codes.OK, json=query_response, headers=headers) + + httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") + + assert len(cursor._set_parameters) == 0 + + cursor.execute("set a = b") + assert ( + len(cursor._set_parameters) == 1 + and "a" in cursor._set_parameters + and cursor._set_parameters["a"] == "b" + ) + + httpx_mock.reset(True) + httpx_mock.add_callback( + query_callback_with_headers, + url=f"{set_query_url}&a=b&output_format=JSON_Compact", + ) + cursor.execute("SELECT 1") + assert len(cursor._set_parameters) == 0 + assert bool(cursor.engine_url) is True, "engine url is not set" + assert bool(cursor.database) is True, "database is not set"