diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 85d73e9010a..a6bb94ddfeb 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -4,7 +4,7 @@ import socket from json import JSONDecodeError from types import TracebackType -from typing import Any, Callable, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream @@ -24,6 +24,7 @@ ACCOUNT_ENGINE_URL, ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, ) +from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema DEFAULT_TIMEOUT_SECONDS: int = 5 @@ -166,6 +167,7 @@ async def connect_inner( account_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, use_token_cache: bool = True, + additional_parameters: Dict[str, Any] = {}, ) -> Connection: """Connect to Firebolt database. @@ -183,6 +185,8 @@ async def connect_inner( api_endpoint (str): Firebolt API endpoint. Used for authentication. use_token_cache (bool): Cached authentication token in filesystem. Default: True + additional_parameters (Optional[Dict]): Dictionary of less widely-used + arguments for connection. Note: Providing both `engine_name` and `engine_url` would result in an error. 
@@ -238,7 +242,9 @@ async def connect_inner( assert engine_url is not None engine_url = fix_url_schema(engine_url) - return connection_class(engine_url, database, auth, api_endpoint) + return connection_class( + engine_url, database, auth, api_endpoint, additional_parameters + ) return connect_inner @@ -297,17 +303,19 @@ def __init__( database: str, auth: Auth, api_endpoint: str = DEFAULT_API_URL, + additional_parameters: Dict[str, Any] = {}, ): # Override tcp keepalive settings for connection transport = AsyncHTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() - + connector_versions = additional_parameters.get("connector_versions", []) self._client = AsyncClient( auth=auth, base_url=engine_url, api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), transport=transport, + headers={"User-Agent": get_user_agent_header(connector_versions)}, ) self.api_endpoint = api_endpoint self.engine_url = engine_url @@ -372,6 +380,9 @@ class Connection(BaseConnection): username: Firebolt account username password: Firebolt account password api_endpoint: Optional. Firebolt API endpoint. Used for authentication. + connector_versions: Optional. Tuple of connector name and version or + list of tuples of your connector stack. Useful for tracking custom + connector usage. 
import inspect
import logging
from importlib import import_module
from pathlib import Path
from platform import python_version, release, system
from sys import modules
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel

from firebolt import __version__


class ConnectorVersions(BaseModel):
    """
    Verify correct parameter types.

    Instantiating this model with a value that cannot be validated as a list
    of (name, version) string pairs raises pydantic's ValidationError.
    """

    versions: List[Tuple[str, str]]


logger = logging.getLogger(__name__)


# Each entry is (connector name, function name, trailing file path, module).
# A stack frame matches an entry when its function name equals the second item
# and its file name ends with the third. The fourth item is the module whose
# ``__version__`` is reported; an empty string means no version is looked up.
CONNECTOR_MAP = [
    (
        "DBT",
        "open",
        Path("dbt/adapters/firebolt/connections.py"),
        "dbt.adapters.firebolt",
    ),
    (
        "Airflow",
        "get_conn",
        Path("firebolt_provider/hooks/firebolt.py"),
        "firebolt_provider",
    ),
    (
        "AirbyteDestination",
        "establish_connection",
        Path("destination_firebolt/destination.py"),
        "",
    ),
    (
        "AirbyteDestination",
        "establish_async_connection",
        Path("destination_firebolt/destination.py"),
        "",
    ),
    ("AirbyteSource", "establish_connection", Path("source_firebolt/source.py"), ""),
    (
        "AirbyteSource",
        "establish_async_connection",
        Path("source_firebolt/source.py"),
        "",
    ),
    ("SQLAlchemy", "connect", Path("sqlalchemy/engine/default.py"), "firebolt_db"),
    ("FireboltCLI", "create_connection", Path("firebolt_cli/utils.py"), "firebolt_cli"),
]


def _os_compare(file: Path, expected: Path) -> bool:
    """
    System-independent path comparison.

    Args:
        file: file path to check against
        expected: expected file path

    Returns:
        True if the trailing components of ``file`` equal the components
        of ``expected``
    """
    return file.parts[-len(expected.parts) :] == expected.parts


def get_sdk_properties() -> Tuple[str, str, str, str]:
    """
    Detect Python, OS and SDK versions.

    Returns:
        Python version, SDK version, "<system> <release>" OS string, and
        "ciso8601" if that module is already loaded (empty string otherwise)
    """
    py_version = python_version()
    sdk_version = __version__
    os_version = f"{system()} {release()}"
    ciso = "ciso8601" if "ciso8601" in modules else ""
    logger.debug(
        "Python %s detected. SDK %s OS %s %s",
        py_version,
        sdk_version,
        os_version,
        ciso,
    )
    return (py_version, sdk_version, os_version, ciso)


def detect_connectors() -> Dict[str, str]:
    """
    Detect which connectors are running the code by parsing the stack.
    Exceptions are ignored since this is intended for logging only.

    Returns:
        Mapping of detected connector name to its version. The version is an
        empty string for connectors that have no version module configured.
    """
    connectors: Dict[str, str] = {}
    try:
        # Keep only plain strings from the stack. Binding the FrameInfo list to
        # a local would create a reference cycle (the first entry points back
        # at this very frame) and keep every caller frame, with its locals,
        # alive until the garbage collector runs. context=0 additionally skips
        # reading source lines from disk, which are never used here.
        frames = [
            (info.function, info.filename) for info in inspect.stack(context=0)
        ]
    except Exception:
        logger.debug("Failed to inspect the call stack")
        return connectors

    for function, filename in frames:
        try:
            for name, func, path, version_path in CONNECTOR_MAP:
                if function == func and _os_compare(Path(filename), path):
                    if version_path:
                        m = import_module(version_path)
                        connectors[name] = m.__version__  # type: ignore
                    else:
                        # Some connectors don't have versions specified
                        connectors[name] = ""
                    # No need to carry on if connector is detected
                    break
        except Exception:
            logger.debug(
                "Failed to extract version from %s in %s", function, filename
            )
    return connectors


def format_as_user_agent(connectors: Dict[str, str]) -> str:
    """
    Return a representation of a stored tracking data as a user-agent header.

    Args:
        connectors: Dictionary of connector to version mappings

    Returns:
        SDK description followed by one " <connector>/<version>" entry per
        item of ``connectors``, in dictionary order.
    """
    py, sdk, os, ciso = get_sdk_properties()
    sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso})"
    connector_format = "".join(
        f" {connector}/{version}" for connector, version in connectors.items()
    )
    return sdk_format + connector_format


def get_user_agent_header(
    connector_versions: Optional[List[Tuple[str, str]]] = None,
) -> str:
    """
    Return a user agent header with connector stack and system information.

    Args:
        connector_versions(Optional): User-supplied list of tuples of all connectors
            and their versions intended for tracking. ``None`` is treated the
            same as an empty list.

    Returns:
        String representation of a user-agent tracking information

    Raises:
        pydantic.ValidationError: If ``connector_versions`` is not None and
            cannot be validated as a list of (name, version) string pairs.
    """
    connectors = detect_connectors()
    logger.debug("Detected running from packages: %s", connectors)
    # Only None means "nothing supplied"; any other value, including other
    # falsy ones, still goes through validation.
    supplied = [] if connector_versions is None else connector_versions
    # Override auto-detected connectors with info provided manually
    for name, version in ConnectorVersions(versions=supplied).versions:
        connectors[name] = version
    return format_as_user_agent(connectors)
@fixture(scope="module", autouse=True)
def create_cli_mock():
    """Create mock connector packages exposing ``__version__`` for the module.

    Each file in MOCK_MODULES is written under TEST_FOLDER with version
    ``1.0.<index>``. The folder is removed after the tests of the module,
    and also when the setup itself fails.
    """
    try:
        for i, file in enumerate(MOCK_MODULES):
            # exist_ok: a folder left behind by an interrupted run must not
            # break the next one.
            os.makedirs(os.path.dirname(f"{TEST_FOLDER}{file}"), exist_ok=True)
            with open(f"{TEST_FOLDER}{file}", "w") as f:
                f.write(f"__version__ = '1.0.{i}'")
        # Additional setup for proper dbt import
        Path(f"{TEST_FOLDER}dbt/adapters/firebolt/dbt/__init__.py").touch()
        Path(f"{TEST_FOLDER}/dbt/adapters/firebolt/dbt/adapters/__init__.py").touch()
        yield
    finally:
        # ignore_errors: do not mask a setup failure with a missing folder.
        rmtree(TEST_FOLDER, ignore_errors=True)


@fixture(scope="module")
def test_model():
    """Return the text of the script template at TEST_SCRIPT_MODEL."""
    with open(TEST_SCRIPT_MODEL) as f:
        return f.read()


def create_test_file(code: str, function_name: str, file_path: str):
    """Render the template and write it to ``file_path``.

    Args:
        code: template text containing ``{function_name}`` placeholders
        function_name: name substituted into the template
        file_path: destination file; missing parent folders are created
    """
    code = code.format(function_name=function_name)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(code)


@mark.parametrize(
    "function,path,expected",
    [
        ("create_connection", "firebolt_cli/utils.py", "FireboltCLI/1.0.0"),
        ("connect", "sqlalchemy/engine/default.py", "SQLAlchemy/1.0.1"),
        ("establish_connection", "source_firebolt/source.py", "AirbyteSource/"),
        ("establish_async_connection", "source_firebolt/source.py", "AirbyteSource/"),
        (
            "establish_connection",
            "destination_firebolt/destination.py",
            "AirbyteDestination/",
        ),
        (
            "establish_async_connection",
            "destination_firebolt/destination.py",
            "AirbyteDestination/",
        ),
        ("get_conn", "firebolt_provider/hooks/firebolt.py", "Airflow/1.0.2"),
        ("open", "dbt/adapters/firebolt/connections.py", "DBT/1.0.3"),
    ],
)
def test_usage_detection(function, path, expected, test_model):
    """Run a generated script and check the connector shows up in its output.

    The script is written at ``path`` under TEST_FOLDER with a function called
    ``function``; its stdout must contain ``expected`` and stderr must be empty.
    """
    # Local import keeps this block independent of the file's import section.
    import sys

    test_path = TEST_FOLDER + path
    create_test_file(test_model, function, test_path)
    result = run(
        # Use the interpreter running the tests: a bare "python3" may be
        # missing from PATH or point at a different installation.
        [sys.executable, test_path],
        stdout=PIPE,
        stderr=PIPE,
        env={"PYTHONPATH": os.getenv("PYTHONPATH", ""), "PATH": os.getenv("PATH", "")},
    )
    assert not result.stderr, result.stderr.decode("utf-8", errors="replace")
    assert expected in result.stdout.decode("utf-8")
import patch from httpx import codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises from pytest_httpx import HTTPXMock -from firebolt.async_db import Connection, connect from firebolt.async_db._types import ColType +from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, Token, UsernamePassword from firebolt.common.settings import Settings from firebolt.utils.exception import ( @@ -387,3 +388,58 @@ async def test_connect_account_name( api_endpoint=settings.server, ): pass + + +@mark.asyncio +async def test_connect_with_user_agent( + httpx_mock: HTTPXMock, + settings: Settings, + db_name: str, + query_callback: Callable, + query_url: str, + access_token: str, +) -> None: + with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + ut.return_value = "MyConnector/1.0" + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": "MyConnector/1.0"}, + ) + + async with await connect( + auth=Token(access_token), + database=db_name, + engine_url=settings.server, + account_name=settings.account_name, + api_endpoint=settings.server, + additional_parameters={"connector_versions": [("MyConnector", "1.0")]}, + ) as connection: + await connection.cursor().execute("select*") + ut.assert_called_once_with([("MyConnector", "1.0")]) + + +@mark.asyncio +async def test_connect_no_user_agent( + httpx_mock: HTTPXMock, + settings: Settings, + db_name: str, + query_callback: Callable, + query_url: str, + access_token: str, +) -> None: + with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + ut.return_value = "Python/3.0" + httpx_mock.add_callback( + query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} + ) + + async with await connect( + auth=Token(access_token), + database=db_name, + engine_url=settings.server, + account_name=settings.account_name, + api_endpoint=settings.server, + ) as connection: + await 
connection.cursor().execute("select*") + ut.assert_called_once_with([]) diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 5d44a26e799..f8ce4d34881 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,3 +1,5 @@ +import gc +import warnings from re import Pattern from typing import Callable, List @@ -232,12 +234,22 @@ def test_connection_unclosed_warnings(): c = Connection("", "", None, "") with warns(UserWarning) as winfo: del c + gc.collect() - assert "Unclosed" in str( - winfo.list[0].message + assert any( + "Unclosed" in str(warning.message) for warning in winfo.list ), "Invalid unclosed connection warning" +def test_connection_no_warnings(): + c = Connection("", "", None, "") + c.close() + with warnings.catch_warnings(): + warnings.simplefilter("error") + del c + gc.collect() + + def test_connection_commit(connection: Connection): # nothing happens connection.commit() diff --git a/tests/unit/utils/test_usage_tracker.py b/tests/unit/utils/test_usage_tracker.py new file mode 100644 index 00000000000..60c8da620b6 --- /dev/null +++ b/tests/unit/utils/test_usage_tracker.py @@ -0,0 +1,147 @@ +from collections import namedtuple +from unittest.mock import MagicMock, patch + +from pydantic import ValidationError +from pytest import mark, raises + +from firebolt.utils.usage_tracker import ( + detect_connectors, + get_sdk_properties, + get_user_agent_header, +) + + +@patch("firebolt.utils.usage_tracker.python_version", MagicMock(return_value="3.10.1")) +@patch("firebolt.utils.usage_tracker.release", MagicMock(return_value="2.2.1")) +@patch("firebolt.utils.usage_tracker.system", MagicMock(return_value="Linux")) +@patch("firebolt.utils.usage_tracker.__version__", "0.1.1") +def test_get_sdk_properties(): + with patch.dict("firebolt.utils.usage_tracker.modules", {}, clear=True): + assert ("3.10.1", "0.1.1", "Linux 2.2.1", "") == get_sdk_properties() + with patch.dict( + 
"firebolt.utils.usage_tracker.modules", {"ciso8601": ""}, clear=True + ): + assert ("3.10.1", "0.1.1", "Linux 2.2.1", "ciso8601") == get_sdk_properties() + + +StackItem = namedtuple("StackItem", "function filename") + + +@patch.dict( + "firebolt.utils.usage_tracker.modules", + {"firebolt_cli": MagicMock(__version__="0.1.1")}, +) +@patch.dict( + "firebolt.utils.usage_tracker.modules", + {"firebolt_db": MagicMock(__version__="0.1.2")}, +) +@patch.dict( + "firebolt.utils.usage_tracker.modules", + {"firebolt_provider": MagicMock(__version__="0.1.3")}, +) +@patch.dict( + "firebolt.utils.usage_tracker.modules", + { + "dbt": MagicMock(), + "dbt.adapters": MagicMock(), + "dbt.adapters.firebolt": MagicMock(__version__="0.1.4"), + }, +) +@mark.parametrize( + "stack,expected", + [ + ( + [ + StackItem("create_connection", "dir1/dir2/firebolt_cli/utils.py"), + StackItem("dummy", "dummy.py"), + ], + {"FireboltCLI": "0.1.1"}, + ), + ( + [ + StackItem( + "create_connection", + "my_documents/some_other_dir/firebolt_cli/utils.py", + ) + ], + {"FireboltCLI": "0.1.1"}, + ), + ( + [StackItem("connect", "sqlalchemy/engine/default.py")], + {"SQLAlchemy": "0.1.2"}, + ), + ( + [StackItem("establish_connection", "source_firebolt/source.py")], + {"AirbyteSource": ""}, + ), + ( + [StackItem("establish_async_connection", "source_firebolt/source.py")], + {"AirbyteSource": ""}, + ), + ( + [StackItem("establish_connection", "destination_firebolt/destination.py")], + {"AirbyteDestination": ""}, + ), + ( + [ + StackItem( + "establish_async_connection", "destination_firebolt/destination.py" + ) + ], + {"AirbyteDestination": ""}, + ), + ( + [StackItem("get_conn", "firebolt_provider/hooks/firebolt.py")], + {"Airflow": "0.1.3"}, + ), + ([StackItem("open", "dbt/adapters/firebolt/connections.py")], {"DBT": "0.1.4"}), + ], +) +def test_detect_connectors(stack, expected): + with patch( + "firebolt.utils.usage_tracker.inspect.stack", MagicMock(return_value=stack) + ): + assert detect_connectors() == 
expected + + +@mark.parametrize( + "connectors,expected_string", + [ + ([], "PythonSDK/2 (Python 1; Win; ciso)"), + ( + [("ConnectorA", "0.1.1")], + "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1", + ), + ( + (("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")), + "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), + ( + [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), + ], +) +@patch( + "firebolt.utils.usage_tracker.get_sdk_properties", + MagicMock(return_value=("1", "2", "Win", "ciso")), +) +def test_user_agent(connectors, expected_string): + assert get_user_agent_header(connectors) == expected_string + + +@mark.parametrize( + "connectors", + [ + ([1]), + ((("Con1", "v1.1"), ("Con2"))), + (("Connector1.1")), + ], +) +@patch( + "firebolt.utils.usage_tracker.get_sdk_properties", + MagicMock(return_value=("1", "2", "Win", "ciso")), +) +def test_incorrect_user_agent(connectors): + with raises(ValidationError): + get_user_agent_header(connectors)