Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
05defb6
cleanup and rename auth classes
stepansergeevitch Mar 10, 2023
935721a
fix async_db tests
stepansergeevitch Mar 10, 2023
bf509d0
fix db unit tests
stepansergeevitch Mar 10, 2023
a5bb82c
fix remaining tests
stepansergeevitch Mar 10, 2023
82aa976
update auth flow logic
stepansergeevitch Mar 14, 2023
51703b6
update connection logic
stepansergeevitch Mar 20, 2023
b628aad
make database optional in connection
stepansergeevitch Mar 20, 2023
b5e69ec
fix sync connection tests
stepansergeevitch Mar 21, 2023
83a5b44
fix sync cursor tests
stepansergeevitch Mar 21, 2023
eb66f78
fixes for async cursor tests
stepansergeevitch Mar 21, 2023
f2cc825
fix engine get connection function
stepansergeevitch Mar 21, 2023
21b2b89
fix nested loops in trio
stepansergeevitch Mar 21, 2023
af858a4
update connection logic, fix engine running check
stepansergeevitch Mar 22, 2023
08f9d91
unit tests fixes
stepansergeevitch Mar 22, 2023
30bf6a6
integration testing WIP
stepansergeevitch Apr 14, 2023
e303f8c
WIP update audience
stepansergeevitch May 10, 2023
e97b9ea
WIP main rebase
stepansergeevitch May 25, 2023
91cc609
fix async unit tests
stepansergeevitch May 25, 2023
a551d2c
fix merge unit test issues
stepansergeevitch May 25, 2023
3580143
add account_id to system engine, minor fixes
stepansergeevitch May 30, 2023
92f0b30
fix async tests
stepansergeevitch May 31, 2023
bf8fd5e
fix db unit tests
stepansergeevitch May 31, 2023
d08b08e
fix integration tests
stepansergeevitch Jun 7, 2023
5b7a711
Merge branch 'main' into FIR-21171-new-identity-support
stepansergeevitch Jun 7, 2023
8afbb38
update trio version
stepansergeevitch Jun 7, 2023
19a225f
address comments
stepansergeevitch Jun 14, 2023
0352ae1
update pytest
stepansergeevitch Jun 14, 2023
6c57e8e
fix code checks
stepansergeevitch Jun 14, 2023
0239c0a
add missing requirement
stepansergeevitch Jun 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docsrc/firebolt.common.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Util
---------------------------

.. automodule:: firebolt.common.util
:exclude-members: async_to_sync, cached_property, fix_url_schema, mixin_for, prune_dict
:exclude-members: cached_property, fix_url_schema, mixin_for, prune_dict
:members:
:undoc-members:
:show-inheritance:
12 changes: 7 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ install_requires =
aiorwlock==1.1.0
appdirs>=1.4.4
appdirs-stubs>=0.1.0
async-property>=0.2.1
cryptography>=3.4.0
httpx[http2]==0.24.0
pydantic[dotenv]>=1.8.2
python-dateutil>=2.8.2
readerwriterlock==1.0.9
readerwriterlock>=1.0.9
sqlparse>=0.4.2
trio<0.22.0
tricycle>=0.2.2
trio>=0.22.0
python_requires = >=3.7
include_package_data = True
package_dir =
Expand All @@ -50,12 +52,12 @@ dev =
mypy==0.910
pre-commit==2.15.0
pyfakefs>=4.5.3
pytest==6.2.5
pytest-asyncio==0.19.0
pytest==7.2.0
pytest-cov==3.0.0
pytest-httpx==0.22.0
pytest-mock==3.6.1
pytest-timeout==2.1.0
pytest-trio==0.8.0
pytest-xdist==2.5.0
trio-typing[mypy]==0.6.*
types-cryptography==3.3.18
Expand Down Expand Up @@ -90,4 +92,4 @@ docstring-convention = google
inline-quotes = "

[tool:pytest]
asyncio_mode = auto
trio_mode = true
238 changes: 75 additions & 163 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,128 +2,38 @@

import logging
import socket
from json import JSONDecodeError
from types import TracebackType
from typing import Any, Dict, List, Optional

from httpcore.backends.auto import AutoBackend
from httpcore.backends.base import AsyncNetworkStream
from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout
from httpx import AsyncHTTPTransport, Timeout

from firebolt.async_db.cursor import Cursor
from firebolt.async_db.util import (
_get_engine_url_status_db,
_get_system_engine_url,
)
from firebolt.client import DEFAULT_API_URL, AsyncClient
from firebolt.client.auth import Auth, _get_auth
from firebolt.client.auth import Auth
from firebolt.common.base_connection import BaseConnection
from firebolt.common.settings import (
DEFAULT_TIMEOUT_SECONDS,
KEEPALIVE_FLAG,
KEEPIDLE_RATE,
)
from firebolt.common.util import validate_engine_name_and_url
from firebolt.utils.exception import (
ConfigurationError,
ConnectionClosedError,
FireboltEngineError,
EngineNotRunningError,
InterfaceError,
)
from firebolt.utils.urls import (
ACCOUNT_ENGINE_ID_BY_NAME_URL,
ACCOUNT_ENGINE_URL,
ACCOUNT_ENGINE_URL_BY_DATABASE_NAME,
)
from firebolt.utils.usage_tracker import get_user_agent_header
from firebolt.utils.util import fix_url_schema

AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials
directly to the `connect` function is deprecated.
Pass the `Auth` object instead.
Examples:
>>> from firebolt.client.auth import UsernamePassword
>>> ...
>>> connect(auth=UsernamePassword(username, password), ...)
or
>>> from firebolt.client.auth import Token
>>> ...
>>> connect(auth=Token(access_token), ...)"""

logger = logging.getLogger(__name__)


async def _resolve_engine_url(
engine_name: str,
auth: Auth,
api_endpoint: str,
account_name: Optional[str] = None,
) -> str:
async with AsyncClient(
auth=auth,
base_url=api_endpoint,
account_name=account_name,
api_endpoint=api_endpoint,
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS),
) as client:
account_id = await client.account_id
url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id)
try:
response = await client.get(
url=url,
params={"engine_name": engine_name},
)
response.raise_for_status()
engine_id = response.json()["engine_id"]["engine_id"]
url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id)
response = await client.get(url=url)
response.raise_for_status()
return response.json()["engine"]["endpoint"]
except HTTPStatusError as e:
# Engine error would be 404.
if e.response.status_code != 404:
raise InterfaceError(
f"Error {e.__class__.__name__}: Unable to retrieve engine "
f"endpoint {url}."
)
# Once this is point is reached we've already authenticated with
# the backend so it's safe to assume the cause of the error is
# missing engine.
raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.")
except (JSONDecodeError, RequestError, RuntimeError) as e:
raise InterfaceError(
f"Error {e.__class__.__name__}: "
f"Unable to retrieve engine endpoint {url}."
)


async def _get_database_default_engine_url(
database: str,
auth: Auth,
api_endpoint: str,
account_name: Optional[str] = None,
) -> str:
async with AsyncClient(
auth=auth,
base_url=api_endpoint,
account_name=account_name,
api_endpoint=api_endpoint,
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS),
) as client:
try:
account_id = await client.account_id
response = await client.get(
url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id),
params={"database_name": database},
)
response.raise_for_status()
return response.json()["engine_url"]
except (
JSONDecodeError,
RequestError,
RuntimeError,
HTTPStatusError,
KeyError,
) as e:
raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.")


class OverriddenHttpBackend(AutoBackend):
"""
`OverriddenHttpBackend` is a short-term solution for the TCP
Expand Down Expand Up @@ -189,24 +99,27 @@ class Connection(BaseConnection):

"""

client_class: type
__slots__ = (
"_client",
"_cursors",
"database",
"engine_url",
"api_endpoint",
"_is_closed",
"_system_engine_connection",
)

def __init__(
self,
engine_url: str,
database: str,
database: Optional[str],
auth: Auth,
api_endpoint: str = DEFAULT_API_URL,
account_name: str,
system_engine_connection: Optional["Connection"],
api_endpoint: str,
additional_parameters: Dict[str, Any] = {},
):
super().__init__()
self.api_endpoint = api_endpoint
self.engine_url = engine_url
self.database = database
Expand All @@ -217,14 +130,15 @@ def __init__(
user_drivers = additional_parameters.get("user_drivers", [])
user_clients = additional_parameters.get("user_clients", [])
self._client = AsyncClient(
account_name=account_name,
auth=auth,
base_url=engine_url,
api_endpoint=api_endpoint,
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
transport=transport,
headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)},
)
super().__init__()
self._system_engine_connection = system_engine_connection

def cursor(self, **kwargs: Any) -> Cursor:
if self.closed:
Expand Down Expand Up @@ -256,96 +170,94 @@ async def aclose(self) -> None:
await self._client.aclose()
self._is_closed = True

if self._system_engine_connection:
await self._system_engine_connection.aclose()

async def __aexit__(
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
) -> None:
await self.aclose()


async def connect(
database: str = None,
username: Optional[str] = None,
password: Optional[str] = None,
access_token: Optional[str] = None,
auth: Auth = None,
engine_name: Optional[str] = None,
engine_url: Optional[str] = None,
auth: Optional[Auth] = None,
account_name: Optional[str] = None,
database: Optional[str] = None,
engine_name: Optional[str] = None,
api_endpoint: str = DEFAULT_API_URL,
use_token_cache: bool = True,
additional_parameters: Dict[str, Any] = {},
) -> Connection:
"""Connect to Firebolt database.
"""Connect to Firebolt.

Args:
`auth` (Auth) Authentication object
`database` (str): Name of the database to connect
`username` (Optional[str]): User name to use for authentication (Deprecated)
`password` (Optional[str]): Password to use for authentication (Deprecated)
`access_token` (Optional[str]): Authentication token to use instead of
credentials (Deprecated)
`auth` (Auth)L Authentication object.
`engine_name` (Optional[str]): Name of the engine to connect to
`engine_url` (Optional[str]): The engine endpoint to use
`account_name` (Optional[str]): For customers with multiple accounts;
if none, default is used
`api_endpoint` (str): Firebolt API endpoint. Used for authentication
`use_token_cache` (bool): Cached authentication token in filesystem
Default: True
`additional_parameters` (Optional[Dict]): Dictionary of less widely-used
arguments for connection

Note:
Providing both `engine_name` and `engine_url` will result in an error

"""
# These parameters are optional in function signature
# but are required to connect.
# PEP 249 recommends making them kwargs.
if not database:
raise ConfigurationError("database name is required to connect.")
for name, value in (("auth", auth), ("account_name", account_name)):
if not value:
raise ConfigurationError(f"{name} is required to connect.")

validate_engine_name_and_url(engine_name, engine_url)
# Type checks
assert auth is not None
assert account_name is not None

if not auth:
if any([username, password, access_token, api_endpoint, use_token_cache]):
logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE)
auth = _get_auth(username, password, access_token, use_token_cache)
else:
raise ConfigurationError("No authentication provided.")
api_endpoint = fix_url_schema(api_endpoint)

# Mypy checks, this should never happen
assert database is not None
system_engine_url = fix_url_schema(
await _get_system_engine_url(auth, account_name, api_endpoint)
)
# Don't use context manager since this will be stored
# and used in a resulting connection
system_engine_connection = Connection(
system_engine_url,
database,
auth,
account_name,
None,
api_endpoint,
additional_parameters,
)

if not engine_name and not engine_url:
engine_url = await _get_database_default_engine_url(
database=database,
auth=auth,
account_name=account_name,
api_endpoint=api_endpoint,
)
if not engine_name:
return system_engine_connection

elif engine_name:
engine_url = await _resolve_engine_url(
engine_name=engine_name,
auth=auth,
account_name=account_name,
api_endpoint=api_endpoint,
)
elif account_name:
# In above if branches account name is validated since it's used to
# resolve or get an engine url.
# We need to manually validate account_name if none of the above
# cases are triggered.
async with AsyncClient(
auth=auth,
base_url=api_endpoint,
account_name=account_name,
api_endpoint=api_endpoint,
) as client:
await client.account_id
else:
try:
engine_url, status, attached_db = await _get_engine_url_status_db(
system_engine_connection, engine_name
)

assert engine_url is not None
if status != "Running":
raise EngineNotRunningError(engine_name)

engine_url = fix_url_schema(engine_url)
return Connection(engine_url, database, auth, api_endpoint, additional_parameters)
if database is not None and database != attached_db:
raise InterfaceError(
f"Engine {engine_name} is not attached to {database}, "
f"but to {attached_db}"
)
elif database is None:
database = attached_db

assert engine_url is not None

engine_url = fix_url_schema(engine_url)
return Connection(
engine_url,
database,
auth,
account_name,
system_engine_connection,
api_endpoint,
additional_parameters,
)
except: # noqa
await system_engine_connection.aclose()
raise
Loading