def connect(
    database: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    engine_name: Optional[str] = None,
    engine_url: Optional[str] = None,
    api_endpoint: str = DEFAULT_API_URL,
) -> "Connection":
    """Connect to a Firebolt database and return a Connection.

    Connection parameters:
        database - name of the database to connect to
        username - user name to use for authentication
        password - password to use for authentication
        engine_name - name of the engine to connect to
        engine_url - engine endpoint to use
        api_endpoint(optional) - Firebolt API endpoint, used for authentication
            and for resolving engine_name into an endpoint

    Exactly one of engine_name and engine_url has to be provided. An empty
    string counts as "not provided" for every parameter.

    Raises:
        InterfaceError - both or neither of engine_name/engine_url are given,
            a required parameter is missing or empty, or the engine found by
            name has no endpoint.
    """
    if engine_name and engine_url:
        raise InterfaceError(
            "Both engine_name and engine_url are provided. Provide only one to connect."
        )
    if not engine_name and not engine_url:
        raise InterfaceError(
            "Neither engine_name nor engine_url are provided. Provide one to connect."
        )

    # These parameters are optional in the signature (PEP 249 recommends
    # keyword parameters for connect), but they are required to connect.
    # Explicit checks also narrow Optional[str] to str for the type checker.
    if not database:
        raise InterfaceError("database is required to connect.")
    if not username:
        raise InterfaceError("username is required to connect.")
    if not password:
        raise InterfaceError("password is required to connect.")

    # Truthiness test, not "is not None": callers may pass engine_name=""
    # together with a real engine_url, and that must not trigger a lookup.
    if engine_name:
        rm = ResourceManager(
            Settings(
                user=username,
                password=password,
                server=api_endpoint,
                default_region="",
            )
        )
        endpoint = rm.engines.get_by_name(engine_name).endpoint
        if endpoint is None:
            raise InterfaceError("unable to retrieve engine endpoint.")
        engine_url = endpoint

    # Invariant guaranteed by the validation above; kept only so the type
    # checker can narrow Optional[str].
    assert engine_url is not None

    # Prepend a scheme only when none is present. Checking for the full
    # scheme prefix keeps hosts that merely start with "http" working.
    if not engine_url.startswith(("http://", "https://")):
        engine_url = f"https://{engine_url}"

    return Connection(engine_url, database, username, password, api_endpoint)
@fixture
def connection_engine_name(
    engine_name: str,
    database_name: str,
    username: str,
    password: str,
    api_endpoint: str,
) -> Connection:
    """Connection opened through connect() using the engine name fixture."""
    connect_kwargs = {
        "engine_name": engine_name,
        "database": database_name,
        "username": username,
        "password": password,
        "api_endpoint": api_endpoint,
    }
    return connect(**connect_kwargs)
def test_engine_name_not_exists(
    engine_name: str,
    database_name: str,
    username: str,
    password: str,
    api_endpoint: str,
) -> None:
    """A connection built from a mangled engine name fails on the first query."""
    # NOTE(review): the mangled value is handed over as engine_url, not as
    # engine_name, so engine lookup by name is not exercised here. Confirm
    # whether that is intended.
    bogus_endpoint = engine_name + "_________"
    broken = connect(
        engine_url=bogus_endpoint,
        database=database_name,
        username=username,
        password=password,
        api_endpoint=api_endpoint,
    )
    with raises(ConnectError):
        broken.cursor().execute("show tables")
def test_connect_engine_name(
    connection_engine_name: Connection,
    all_types_query: str,
    all_types_query_description: List[Column],
    all_types_query_response: List[ColType],
) -> None:
    """Connecting with engine name is handled properly."""
    # Reuse the plain select scenario, only with the connection that was
    # created from an engine name.
    select_args = (
        connection_engine_name,
        all_types_query,
        all_types_query_description,
        all_types_query_response,
    )
    test_select(*select_args)
@pytest.fixture
def get_providers_callback(get_providers_url: str, provider: Provider) -> Callable:
    """Build an httpx mock callback answering the providers listing request."""

    def do_mock(
        request: httpx.Request = None,
        **kwargs,
    ) -> Response:
        # The callback must only ever be hit for the providers URL.
        assert request.url == get_providers_url
        payload = list_to_paginated_response([provider])
        return to_response(status_code=httpx.codes.OK, json=payload)

    return do_mock
def test_connect_empty_parameters():
    """connect rejects a call that omits any one of the required parameters."""
    params = ("database", "username", "password")

    for param in params:
        # Build the arguments outside of the raises block so that only the
        # connect call itself can satisfy the expected exception.
        kwargs = {
            "engine_url": "engine_url",
            **{p: p for p in params if p != param},
        }
        with raises(InterfaceError) as exc_info:
            connect(**kwargs)
        # Checked after the with block: statements placed after the raising
        # call inside it would never run.
        assert str(exc_info.value) == f"{param} is required to connect."
def test_connect_engine_name(
    settings: Settings,
    db_name: str,
    httpx_mock: HTTPXMock,
    auth_callback: Callable,
    auth_url: str,
    query_callback: Callable,
    query_url: str,
    account_id_url: str,
    account_id_callback: Callable,
    engine_id: str,
    get_engine_url: str,
    get_engine_callback: Callable,
    get_providers_url: str,
    get_providers_callback: Callable,
    python_query_data: List[List[ColType]],
    account_id: str,
):
    """connect properly handles engine_name"""

    # Supplying both engine_url and engine_name is rejected.
    with raises(InterfaceError) as both_given:
        connect(
            engine_url="engine_url",
            engine_name="engine_name",
            database="db",
            username="username",
            password="password",
        )
    assert str(both_given.value).startswith(
        "Both engine_name and engine_url are provided"
    )

    # Supplying neither of them is rejected as well.
    with raises(InterfaceError) as none_given:
        connect(
            database="db",
            username="username",
            password="password",
        )
    assert str(none_given.value).startswith(
        "Neither engine_name nor engine_url are provided"
    )

    mocked_endpoints = (
        (auth_callback, auth_url),
        (query_callback, query_url),
        (account_id_callback, account_id_url),
        (get_engine_callback, get_engine_url),
        (get_providers_callback, get_providers_url),
    )
    for callback, url in mocked_endpoints:
        httpx_mock.add_callback(callback, url=url)

    engine_name = settings.server.split(".")[0]

    # Mock engine id lookup by name
    lookup_url = (
        f"https://{settings.server}/core/v1/account/engines:getIdByName?"
        f"engine_name={engine_name}"
    )
    httpx_mock.add_response(
        url=lookup_url,
        status_code=codes.OK,
        json={"engine_id": {"engine_id": engine_id}},
    )

    connection = connect(
        engine_name=engine_name,
        database=db_name,
        username="u",
        password="p",
        api_endpoint=settings.server,
    )
    cursor = connection.cursor()

    assert cursor.execute("select*") == len(python_query_data)