diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index a6bb94ddfeb..1e0f4e10486 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -308,14 +308,15 @@ def __init__( # Override tcp keepalive settings for connection transport = AsyncHTTPTransport() transport._pool._network_backend = OverriddenHttpBackend() - connector_versions = additional_parameters.get("connector_versions", []) + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) self._client = AsyncClient( auth=auth, base_url=engine_url, api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), transport=transport, - headers={"User-Agent": get_user_agent_header(connector_versions)}, + headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) self.api_endpoint = api_endpoint self.engine_url = engine_url diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index 2842e73d6f3..624ca7ef04f 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -16,19 +16,14 @@ class ConnectorVersions(BaseModel): Verify correct parameter types """ - versions: List[Tuple[str, str]] + clients: List[Tuple[str, str]] + drivers: List[Tuple[str, str]] logger = logging.getLogger(__name__) -CONNECTOR_MAP = [ - ( - "DBT", - "open", - Path("dbt/adapters/firebolt/connections.py"), - "dbt.adapters.firebolt", - ), +CLIENT_MAP = [ ( "Airflow", "get_conn", @@ -54,10 +49,19 @@ class ConnectorVersions(BaseModel): Path("source_firebolt/source.py"), "", ), - ("SQLAlchemy", "connect", Path("sqlalchemy/engine/default.py"), "firebolt_db"), ("FireboltCLI", "create_connection", Path("firebolt_cli/utils.py"), "firebolt_cli"), ] +DRIVER_MAP = [ + ( + "DBT", + "open", + Path("dbt/adapters/firebolt/connections.py"), + "dbt.adapters.firebolt", + ), + ("SQLAlchemy", "connect", Path("sqlalchemy/engine/default.py"), "firebolt_db"), +] + def _os_compare(file: Path, expected: Path) -> bool: """ @@ -94,7 +98,9 @@ def get_sdk_properties() -> Tuple[str, str, str, str]: return (py_version, sdk_version, os_version, ciso) -def detect_connectors() -> Dict[str, str]: +def detect_connectors( + connector_map: List[Tuple[str, str, Path, str]] +) -> Dict[str, str]: """ Detect which connectors are running the code by parsing the stack. Exceptions are ignored since this is intended for logging only. @@ -103,7 +109,7 @@ def detect_connectors() -> Dict[str, str]: stack = inspect.stack() for f in stack: try: - for name, func, path, version_path in CONNECTOR_MAP: + for name, func, path, version_path in connector_map: if f.function == func and _os_compare(Path(f.filename), path): if version_path: m = import_module(version_path) @@ -120,7 +126,7 @@ def detect_connectors() -> Dict[str, str]: return connectors -def format_as_user_agent(connectors: Dict[str, str]) -> str: +def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> str: """ Return a representation of a stored tracking data as a user-agent header. @@ -132,28 +138,44 @@ def format_as_user_agent(connectors: Dict[str, str]) -> str: """ py, sdk, os, ciso = get_sdk_properties() sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso})" - connector_format = "".join( - [f" {connector}/{version}" for connector, version in connectors.items()] + driver_format = "".join( + [f" {connector}/{version}" for connector, version in drivers.items()] + ) + client_format = "".join( + [f"{connector}/{version} " for connector, version in clients.items()] ) - return sdk_format + connector_format + return client_format + sdk_format + driver_format def get_user_agent_header( - connector_versions: Optional[List[Tuple[str, str]]] = [] + user_drivers: Optional[List[Tuple[str, str]]] = [], + user_clients: Optional[List[Tuple[str, str]]] = [], ) -> str: """ Return a user agent header with connector stack and system information. Args: - connector_versions(Optional): User-supplied list of tuples of all connectors - and their versions intended for tracking. + user_drivers(Optional): User-supplied list of tuples of all drivers + and their versions intended for tracking. Driver is a programmatic + module that facilitates interaction between a clients and underlying + database. + user_clients(Optional): User-supplied list of tuples of all clients + and their versions intended for tracking. Client is a user-facing + module or application that allows interaction with the database + via drivers or directly. Returns: String representation of a user-agent tracking information """ - connectors = detect_connectors() - logger.debug("Detected running from packages: %s", str(connectors)) + drivers = detect_connectors(DRIVER_MAP) + clients = detect_connectors(CLIENT_MAP) + logger.debug( + "Detected running with drivers: %s and clients %s ", str(drivers), str(clients) + ) # Override auto-detected connectors with info provided manually - for name, version in ConnectorVersions(versions=connector_versions).versions: - connectors[name] = version - return format_as_user_agent(connectors) + versions = ConnectorVersions(clients=user_clients, drivers=user_drivers) + for name, version in versions.clients: + clients[name] = version + for name, version in versions.drivers: + drivers[name] = version + return format_as_user_agent(drivers, clients) diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 15d8f904b7a..959a80c60c9 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -400,11 +400,11 @@ async def test_connect_with_user_agent( access_token: str, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: - ut.return_value = "MyConnector/1.0" + ut.return_value = "MyConnector/1.0 DriverA/1.1" httpx_mock.add_callback( query_callback, url=query_url, - match_headers={"User-Agent": "MyConnector/1.0"}, + match_headers={"User-Agent": "MyConnector/1.0 DriverA/1.1"}, ) async with await connect( @@ -413,10 +413,13 @@ async def test_connect_with_user_agent( engine_url=settings.server, account_name=settings.account_name, api_endpoint=settings.server, - additional_parameters={"connector_versions": [("MyConnector", "1.0")]}, + additional_parameters={ + "user_clients": [("MyConnector", "1.0")], + "user_drivers": [("DriverA", "1.1")], + }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([("MyConnector", "1.0")]) + ut.assert_called_once_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) @mark.asyncio @@ -442,4 +445,4 @@ async def test_connect_no_user_agent( api_endpoint=settings.server, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([]) + ut.assert_called_once_with([], []) diff --git a/tests/unit/utils/test_usage_tracker.py b/tests/unit/utils/test_usage_tracker.py index 60c8da620b6..55eed561bf4 100644 --- a/tests/unit/utils/test_usage_tracker.py +++ b/tests/unit/utils/test_usage_tracker.py @@ -5,6 +5,8 @@ from pytest import mark, raises from firebolt.utils.usage_tracker import ( + CLIENT_MAP, + DRIVER_MAP, detect_connectors, get_sdk_properties, get_user_agent_header, @@ -48,13 +50,14 @@ def test_get_sdk_properties(): }, ) @mark.parametrize( - "stack,expected", + "stack,map,expected", [ ( [ StackItem("create_connection", "dir1/dir2/firebolt_cli/utils.py"), StackItem("dummy", "dummy.py"), ], + CLIENT_MAP, {"FireboltCLI": "0.1.1"}, ), ( @@ -64,22 +67,27 @@ def test_get_sdk_properties(): "my_documents/some_other_dir/firebolt_cli/utils.py", ) ], + CLIENT_MAP, {"FireboltCLI": "0.1.1"}, ), ( [StackItem("connect", "sqlalchemy/engine/default.py")], + DRIVER_MAP, {"SQLAlchemy": "0.1.2"}, ), ( [StackItem("establish_connection", "source_firebolt/source.py")], + CLIENT_MAP, {"AirbyteSource": ""}, ), ( [StackItem("establish_async_connection", "source_firebolt/source.py")], + CLIENT_MAP, {"AirbyteSource": ""}, ), ( [StackItem("establish_connection", "destination_firebolt/destination.py")], + CLIENT_MAP, {"AirbyteDestination": ""}, ), ( @@ -88,60 +96,85 @@ def test_get_sdk_properties(): "establish_async_connection", "destination_firebolt/destination.py" ) ], + CLIENT_MAP, {"AirbyteDestination": ""}, ), ( [StackItem("get_conn", "firebolt_provider/hooks/firebolt.py")], + CLIENT_MAP, {"Airflow": "0.1.3"}, ), - ([StackItem("open", "dbt/adapters/firebolt/connections.py")], {"DBT": "0.1.4"}), + ( + [StackItem("open", "dbt/adapters/firebolt/connections.py")], + DRIVER_MAP, + {"DBT": "0.1.4"}, + ), + ( + [StackItem("open", "dbt/adapters/firebolt/connections.py")], + CLIENT_MAP, + {}, + ), ], ) -def test_detect_connectors(stack, expected): +def test_detect_connectors(stack, map, expected): with patch( "firebolt.utils.usage_tracker.inspect.stack", MagicMock(return_value=stack) ): - assert detect_connectors() == expected + assert detect_connectors(map) == expected @mark.parametrize( - "connectors,expected_string", + "drivers,clients,expected_string", [ - ([], "PythonSDK/2 (Python 1; Win; ciso)"), + ([], [], "PythonSDK/2 (Python 1; Win; ciso)"), ( [("ConnectorA", "0.1.1")], + [], "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1", ), ( (("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")), + (), "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), ( [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + [], "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), + ( + [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + [("ClientA", "1.0.1")], + "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), ], ) @patch( "firebolt.utils.usage_tracker.get_sdk_properties", MagicMock(return_value=("1", "2", "Win", "ciso")), ) -def test_user_agent(connectors, expected_string): - assert get_user_agent_header(connectors) == expected_string +def test_user_agent(drivers, clients, expected_string): + assert get_user_agent_header(drivers, clients) == expected_string @mark.parametrize( - "connectors", + "drivers,clients", [ - ([1]), - ((("Con1", "v1.1"), ("Con2"))), - (("Connector1.1")), + ([1], []), + ((("Con1", "v1.1"), ("Con2")), []), + (("Connector1.1"), ()), + ( + [], + [1], + ), + ([], (("Con1", "v1.1"), ("Con2"))), + ((), ("Connector1.1")), ], ) @patch( "firebolt.utils.usage_tracker.get_sdk_properties", MagicMock(return_value=("1", "2", "Win", "ciso")), ) -def test_incorrect_user_agent(connectors): +def test_incorrect_user_agent(drivers, clients): with raises(ValidationError): - get_user_agent_header(connectors) + get_user_agent_header(drivers, clients)