Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from firebolt.client.auth import Auth
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
from firebolt.common.base_connection import BaseConnection
from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS
from firebolt.common.constants import (
DEFAULT_TIMEOUT_SECONDS,
ENGINE_STATUS_RUNNING_LIST,
)
from firebolt.utils.exception import (
ConfigurationError,
ConnectionClosedError,
Expand Down Expand Up @@ -71,6 +74,7 @@ def __init__(
cursor_type: Type[Cursor],
system_engine_connection: Optional["Connection"],
api_endpoint: str,
init_parameters: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.api_endpoint = api_endpoint
Expand All @@ -80,6 +84,7 @@ def __init__(
self._cursors: List[Cursor] = []
self._system_engine_connection = system_engine_connection
self._client = client
self.init_parameters = init_parameters

def cursor(self, **kwargs: Any) -> Cursor:
if self.closed:
Expand Down Expand Up @@ -142,8 +147,8 @@ async def connect(
user_agent_header = get_user_agent_header(user_drivers, user_clients)
# Use v2 if auth is ClientCredentials
# Use v1 if auth is ServiceAccount or UsernamePassword
version = auth.get_firebolt_version()
if version == 2:
auth_version = auth.get_firebolt_version()
if auth_version == 2:
assert account_name is not None
return await connect_v2(
auth=auth,
Expand All @@ -153,7 +158,7 @@ async def connect(
engine_name=engine_name,
api_endpoint=api_endpoint,
)
elif version == 1:
elif auth_version == 1:
return await connect_v1(
auth=auth,
user_agent_header=user_agent_header,
Expand Down Expand Up @@ -223,6 +228,26 @@ async def connect_v2(
None,
api_endpoint,
)

account_version = await system_engine_connection._client._account_version
if account_version == 2:
cursor = system_engine_connection.cursor()
if database:
await cursor.execute(f"USE DATABASE {database}")
if engine_name:
await cursor.execute(f"USE ENGINE {engine_name}")
# Ensure cursors created from this conection are using the same starting
# database and engine
return Connection(
cursor.engine_url,
cursor.database,
client,
CursorV2,
system_engine_connection,
api_endpoint,
cursor.parameters,
)

if not engine_name:
return system_engine_connection

Expand All @@ -237,7 +262,7 @@ async def connect_v2(
attached_db,
) = await cursor._get_engine_url_status_db(engine_name)

if status != "Running":
if status not in ENGINE_STATUS_RUNNING_LIST:
raise EngineNotRunningError(engine_name)

if database is not None and database != attached_db:
Expand Down
65 changes: 43 additions & 22 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@
Union,
)

from httpx import URL
from httpx import AsyncClient as HttpxAsyncClient
from httpx import Response, codes
from httpx import URL, Headers, Response, codes

from firebolt.async_db.util import ENGINE_STATUS_RUNNING
from firebolt.client.client import AsyncClientV1, AsyncClientV2
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
from firebolt.common._types import (
ColType,
Column,
Expand All @@ -33,14 +30,20 @@
)
from firebolt.common.base_cursor import (
JSON_OUTPUT_FORMAT,
RESET_SESSION_HEADER,
UPDATE_ENDPOINT_HEADER,
UPDATE_PARAMETERS_HEADER,
BaseCursor,
CursorState,
QueryStatus,
Statistics,
_parse_update_endpoint,
_parse_update_parameters,
_raise_if_internal_set_parameter,
check_not_closed,
check_query_executed,
)
from firebolt.common.constants import ENGINE_STATUS_RUNNING_LIST
from firebolt.utils.exception import (
AsyncExecutionUnavailableError,
EngineNotRunningError,
Expand Down Expand Up @@ -76,23 +79,18 @@ class Cursor(BaseCursor, metaclass=ABCMeta):
def __init__(
self,
*args: Any,
client: HttpxAsyncClient,
client: AsyncClient,
connection: Connection,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._client = client
self.connection = connection
self.engine_url = connection.engine_url
if connection.database:
self.database = connection.database

@property
def database(self) -> Optional[str]:
return self.parameters.get("database")

@database.setter
def database(self, database: str) -> None:
self.parameters["database"] = database
if connection.init_parameters:
self._update_set_parameters(connection.init_parameters)

@abstractmethod
async def _api_request(
Expand All @@ -117,9 +115,9 @@ async def _raise_if_error(self, resp: Response) -> None:
if (
resp.status_code == codes.SERVICE_UNAVAILABLE
or resp.status_code == codes.NOT_FOUND
) and not await self.is_engine_running(self.connection.engine_url):
) and not await self.is_engine_running(self.engine_url):
raise EngineNotRunningError(
f"Firebolt engine {self.connection.engine_url} "
f"Firebolt engine {self.engine_url} "
"needs to be running to run queries against it."
)
_print_error_body(resp)
Expand All @@ -143,6 +141,30 @@ async def _validate_set_parameter(self, parameter: SetParameter) -> None:
# set parameter passed validation
self._set_parameters[parameter.name] = parameter.value

async def _parse_response_headers(self, headers: Headers) -> None:
if headers.get(UPDATE_ENDPOINT_HEADER):
endpoint, params = _parse_update_endpoint(
headers.get(UPDATE_ENDPOINT_HEADER)
)
if (
params.get("account_id", await self._client.account_id)
!= await self._client.account_id
):
raise OperationalError(
"USE ENGINE command failed. Account parameter mismatch. "
"Contact support"
)
self._update_set_parameters(params)
self.engine_url = endpoint
self._client.base_url = URL(endpoint)

if headers.get(RESET_SESSION_HEADER):
self.flush_parameters()

if headers.get(UPDATE_PARAMETERS_HEADER):
param_dict = _parse_update_parameters(headers.get(UPDATE_PARAMETERS_HEADER))
self._update_set_parameters(param_dict)

async def _do_execute(
self,
raw_query: str,
Expand Down Expand Up @@ -209,8 +231,7 @@ async def _do_execute(
query, {"output_format": JSON_OUTPUT_FORMAT}
)
await self._raise_if_error(resp)
# get parameters from response
self._parse_response_headers(resp.headers)
await self._parse_response_headers(resp.headers)
row_set = self._row_set_from_response(resp)

self._append_row_set(row_set)
Expand Down Expand Up @@ -452,7 +473,8 @@ async def _api_request(
parameters = {**(self._set_parameters or {}), **parameters}
if self.parameters:
parameters = {**self.parameters, **parameters}
if self.connection._is_system:
# Engines v2 always require account_id
if self.connection._is_system or (await self._client._account_version) == 2:
assert isinstance(self._client, AsyncClientV2)
parameters["account_id"] = await self._client.account_id
return await self._client.request(
Expand Down Expand Up @@ -495,16 +517,15 @@ async def is_engine_running(self, engine_url: str) -> bool:
# System engine is always running
return True

engine_name = URL(engine_url).host.split(".")[0].replace("-", "_")
assert self.connection._system_engine_connection is not None # Type check
system_cursor = self.connection._system_engine_connection.cursor()
assert isinstance(system_cursor, CursorV2) # Type check, should always be true
(
_,
status,
_,
) = await system_cursor._get_engine_url_status_db(engine_name)
return status == ENGINE_STATUS_RUNNING
) = await system_cursor._get_engine_url_status_db(self.engine_name)
return status in ENGINE_STATUS_RUNNING_LIST

async def _get_engine_url_status_db(self, engine_name: str) -> Tuple[str, str, str]:
await self.execute(
Expand Down
2 changes: 0 additions & 2 deletions src/firebolt/async_db/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
)
from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME

ENGINE_STATUS_RUNNING = "Running"


async def _get_system_engine_url(
auth: Auth,
Expand Down
2 changes: 1 addition & 1 deletion src/firebolt/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

DEFAULT_API_URL: str = "api.app.firebolt.io"
PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version"
PROTOCOL_VERSION: str = "2.0"
PROTOCOL_VERSION: str = "2.1"
_REQUEST_ERRORS: Tuple[Type, ...] = (
HTTPError,
InvalidURL,
Expand Down
92 changes: 68 additions & 24 deletions src/firebolt/common/base_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from httpx import Headers, Response
from httpx import URL, Response

from firebolt.common._types import (
ColType,
Expand All @@ -25,7 +25,7 @@
DataError,
QueryNotRunError,
)
from firebolt.utils.util import Timer
from firebolt.utils.util import Timer, fix_url_schema

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,13 +53,32 @@ class QueryStatus(Enum):
EXECUTION_ERROR = 8


# known parameters that can be set on the server side
SERVER_SIDE_PARAMETERS = ["database"]

# Parameters that should be set using USE instead of SET
USE_PARAMETER_LIST = ["database", "engine"]
# parameters that can only be set by the backend
DISALLOWED_PARAMETER_LIST = ["account_id", "output_format"]
# parameters that are set by the backend and should not be set by the user
IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST

UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint"
UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters"
RESET_SESSION_HEADER = "Firebolt-Reset-Session"


def _parse_update_parameters(parameter_header: str) -> Dict[str, str]:
"""Parse update parameters and set them as attributes."""
# parse key1=value1,key2=value2 comma separated string into dict
param_dict = dict(item.split("=") for item in parameter_header.split(","))
# strip whitespace from keys and values
param_dict = {key.strip(): value.strip() for key, value in param_dict.items()}
return param_dict


def _parse_update_endpoint(
new_engine_endpoint_header: str,
) -> Tuple[str, Dict[str, str]]:
endpoint = URL(fix_url_schema(new_engine_endpoint_header))
return fix_url_schema(endpoint.host), dict(endpoint.params)


def _raise_if_internal_set_parameter(parameter: SetParameter) -> None:
Expand Down Expand Up @@ -150,6 +169,7 @@ class BaseCursor:
"_next_set_idx",
"_set_parameters",
"_query_id",
"engine_url",
)

default_arraysize = 1
Expand All @@ -168,14 +188,25 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
Optional[List[List[RawColType]]],
]
] = []
# User-defined set parameters
self._set_parameters: Dict[str, Any] = dict()
# Server-side parameters (user can't change them)
self.parameters: Dict[str, str] = dict()
self.engine_url = ""
self._rowcount = -1
self._idx = 0
self._next_set_idx = 0
self._query_id = ""
self._reset()

@property
def database(self) -> Optional[str]:
return self.parameters.get("database")

@database.setter
def database(self, database: str) -> None:
self.parameters["database"] = database

@property # type: ignore
@check_not_closed
def description(self) -> Optional[List[Column]]:
Expand Down Expand Up @@ -273,25 +304,38 @@ def _reset(self) -> None:
self._next_set_idx = 0
self._query_id = ""

def _parse_response_headers(self, headers: Headers) -> None:
"""Parse response and update relevant cursor fields."""
update_parameters = headers.get("Firebolt-Update-Parameters")
# parse update parameters dict and set keys as attributes
if update_parameters:
# parse key1=value1,key2=value2 comma separated string into dict
param_dict = dict(item.split("=") for item in update_parameters.split(","))
# strip whitespace from keys and values
param_dict = {
key.strip(): value.strip() for key, value in param_dict.items()
}
for key, value in param_dict.items():
if key in SERVER_SIDE_PARAMETERS:
self.parameters[key] = value
else:
logger.debug(
f"Unknown parameter {key} returned by the server. "
"It will be ignored."
)
def _update_set_parameters(self, parameters: Dict[str, Any]) -> None:
# Split parameters into immutable and user parameters
immutable_parameters = {
key: value
for key, value in parameters.items()
if key in IMMUTABLE_PARAMETER_LIST
}
user_parameters = {
key: value
for key, value in parameters.items()
if key not in IMMUTABLE_PARAMETER_LIST
}

self.parameters.update(immutable_parameters)

self._set_parameters.update(user_parameters)

def _update_server_parameters(self, parameters: Dict[str, Any]) -> None:
for key, value in parameters.items():
self.parameters[key] = value

@property
def engine_name(self) -> str:
"""
Get the name of the engine that we're using.

Args:
engine_url (str): URL of the engine
"""
if self.parameters.get("engine"):
return self.parameters["engine"]
return URL(self.engine_url).host.split(".")[0].replace("-", "_")

def _row_set_from_response(
self, response: Response
Expand Down
Loading