Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,15 @@ def __init__(
# Override tcp keepalive settings for connection
transport = AsyncHTTPTransport()
transport._pool._network_backend = OverriddenHttpBackend()
connector_versions = additional_parameters.get("connector_versions", [])
user_drivers = additional_parameters.get("user_drivers", [])
user_clients = additional_parameters.get("user_clients", [])
self._client = AsyncClient(
auth=auth,
base_url=engine_url,
api_endpoint=api_endpoint,
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
transport=transport,
headers={"User-Agent": get_user_agent_header(connector_versions)},
headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)},
)
self.api_endpoint = api_endpoint
self.engine_url = engine_url
Expand Down
68 changes: 45 additions & 23 deletions src/firebolt/utils/usage_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,14 @@ class ConnectorVersions(BaseModel):
Verify correct parameter types
"""

versions: List[Tuple[str, str]]
clients: List[Tuple[str, str]]
drivers: List[Tuple[str, str]]


logger = logging.getLogger(__name__)


CONNECTOR_MAP = [
(
"DBT",
"open",
Path("dbt/adapters/firebolt/connections.py"),
"dbt.adapters.firebolt",
),
CLIENT_MAP = [
(
"Airflow",
"get_conn",
Expand All @@ -54,10 +49,19 @@ class ConnectorVersions(BaseModel):
Path("source_firebolt/source.py"),
"",
),
("SQLAlchemy", "connect", Path("sqlalchemy/engine/default.py"), "firebolt_db"),
("FireboltCLI", "create_connection", Path("firebolt_cli/utils.py"), "firebolt_cli"),
]

DRIVER_MAP = [
(
"DBT",
"open",
Path("dbt/adapters/firebolt/connections.py"),
"dbt.adapters.firebolt",
),
("SQLAlchemy", "connect", Path("sqlalchemy/engine/default.py"), "firebolt_db"),
]


def _os_compare(file: Path, expected: Path) -> bool:
"""
Expand Down Expand Up @@ -94,7 +98,9 @@ def get_sdk_properties() -> Tuple[str, str, str, str]:
return (py_version, sdk_version, os_version, ciso)


def detect_connectors() -> Dict[str, str]:
def detect_connectors(
connector_map: List[Tuple[str, str, Path, str]]
) -> Dict[str, str]:
"""
Detect which connectors are running the code by parsing the stack.
Exceptions are ignored since this is intended for logging only.
Expand All @@ -103,7 +109,7 @@ def detect_connectors() -> Dict[str, str]:
stack = inspect.stack()
for f in stack:
try:
for name, func, path, version_path in CONNECTOR_MAP:
for name, func, path, version_path in connector_map:
if f.function == func and _os_compare(Path(f.filename), path):
if version_path:
m = import_module(version_path)
Expand All @@ -120,7 +126,7 @@ def detect_connectors() -> Dict[str, str]:
return connectors


def format_as_user_agent(connectors: Dict[str, str]) -> str:
def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> str:
"""
Return a representation of a stored tracking data as a user-agent header.

Expand All @@ -132,28 +138,44 @@ def format_as_user_agent(connectors: Dict[str, str]) -> str:
"""
py, sdk, os, ciso = get_sdk_properties()
sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso})"
connector_format = "".join(
[f" {connector}/{version}" for connector, version in connectors.items()]
driver_format = "".join(
[f" {connector}/{version}" for connector, version in drivers.items()]
)
client_format = "".join(
[f"{connector}/{version} " for connector, version in clients.items()]
)
return sdk_format + connector_format
return client_format + sdk_format + driver_format


def get_user_agent_header(
connector_versions: Optional[List[Tuple[str, str]]] = []
user_drivers: Optional[List[Tuple[str, str]]] = [],
user_clients: Optional[List[Tuple[str, str]]] = [],
) -> str:
"""
Return a user agent header with connector stack and system information.

Args:
connector_versions(Optional): User-supplied list of tuples of all connectors
and their versions intended for tracking.
user_drivers(Optional): User-supplied list of tuples of all drivers
and their versions intended for tracking. Driver is a programmatic
module that facilitates interaction between a clients and underlying
database.
user_clients(Optional): User-supplied list of tuples of all clients
and their versions intended for tracking. Client is a user-facing
module or application that allows interaction with the database
via drivers or directly.

Returns:
String representation of a user-agent tracking information
"""
connectors = detect_connectors()
logger.debug("Detected running from packages: %s", str(connectors))
drivers = detect_connectors(DRIVER_MAP)
clients = detect_connectors(CLIENT_MAP)
logger.debug(
"Detected running with drivers: %s and clients %s ", str(drivers), str(clients)
)
# Override auto-detected connectors with info provided manually
for name, version in ConnectorVersions(versions=connector_versions).versions:
connectors[name] = version
return format_as_user_agent(connectors)
versions = ConnectorVersions(clients=user_clients, drivers=user_drivers)
for name, version in versions.clients:
clients[name] = version
for name, version in versions.drivers:
drivers[name] = version
return format_as_user_agent(drivers, clients)
13 changes: 8 additions & 5 deletions tests/unit/async_db/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,11 @@ async def test_connect_with_user_agent(
access_token: str,
) -> None:
with patch("firebolt.async_db.connection.get_user_agent_header") as ut:
ut.return_value = "MyConnector/1.0"
ut.return_value = "MyConnector/1.0 DriverA/1.1"
httpx_mock.add_callback(
query_callback,
url=query_url,
match_headers={"User-Agent": "MyConnector/1.0"},
match_headers={"User-Agent": "MyConnector/1.0 DriverA/1.1"},
)

async with await connect(
Expand All @@ -413,10 +413,13 @@ async def test_connect_with_user_agent(
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
additional_parameters={"connector_versions": [("MyConnector", "1.0")]},
additional_parameters={
"user_clients": [("MyConnector", "1.0")],
"user_drivers": [("DriverA", "1.1")],
},
) as connection:
await connection.cursor().execute("select*")
ut.assert_called_once_with([("MyConnector", "1.0")])
ut.assert_called_once_with([("DriverA", "1.1")], [("MyConnector", "1.0")])


@mark.asyncio
Expand All @@ -442,4 +445,4 @@ async def test_connect_no_user_agent(
api_endpoint=settings.server,
) as connection:
await connection.cursor().execute("select*")
ut.assert_called_once_with([])
ut.assert_called_once_with([], [])
61 changes: 47 additions & 14 deletions tests/unit/utils/test_usage_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pytest import mark, raises

from firebolt.utils.usage_tracker import (
CLIENT_MAP,
DRIVER_MAP,
detect_connectors,
get_sdk_properties,
get_user_agent_header,
Expand Down Expand Up @@ -48,13 +50,14 @@ def test_get_sdk_properties():
},
)
@mark.parametrize(
"stack,expected",
"stack,map,expected",
[
(
[
StackItem("create_connection", "dir1/dir2/firebolt_cli/utils.py"),
StackItem("dummy", "dummy.py"),
],
CLIENT_MAP,
{"FireboltCLI": "0.1.1"},
),
(
Expand All @@ -64,22 +67,27 @@ def test_get_sdk_properties():
"my_documents/some_other_dir/firebolt_cli/utils.py",
)
],
CLIENT_MAP,
{"FireboltCLI": "0.1.1"},
),
(
[StackItem("connect", "sqlalchemy/engine/default.py")],
DRIVER_MAP,
{"SQLAlchemy": "0.1.2"},
),
(
[StackItem("establish_connection", "source_firebolt/source.py")],
CLIENT_MAP,
{"AirbyteSource": ""},
),
(
[StackItem("establish_async_connection", "source_firebolt/source.py")],
CLIENT_MAP,
{"AirbyteSource": ""},
),
(
[StackItem("establish_connection", "destination_firebolt/destination.py")],
CLIENT_MAP,
{"AirbyteDestination": ""},
),
(
Expand All @@ -88,60 +96,85 @@ def test_get_sdk_properties():
"establish_async_connection", "destination_firebolt/destination.py"
)
],
CLIENT_MAP,
{"AirbyteDestination": ""},
),
(
[StackItem("get_conn", "firebolt_provider/hooks/firebolt.py")],
CLIENT_MAP,
{"Airflow": "0.1.3"},
),
([StackItem("open", "dbt/adapters/firebolt/connections.py")], {"DBT": "0.1.4"}),
(
[StackItem("open", "dbt/adapters/firebolt/connections.py")],
DRIVER_MAP,
{"DBT": "0.1.4"},
),
(
[StackItem("open", "dbt/adapters/firebolt/connections.py")],
CLIENT_MAP,
{},
),
],
)
def test_detect_connectors(stack, expected):
def test_detect_connectors(stack, map, expected):
with patch(
"firebolt.utils.usage_tracker.inspect.stack", MagicMock(return_value=stack)
):
assert detect_connectors() == expected
assert detect_connectors(map) == expected


@mark.parametrize(
"connectors,expected_string",
"drivers,clients,expected_string",
[
([], "PythonSDK/2 (Python 1; Win; ciso)"),
([], [], "PythonSDK/2 (Python 1; Win; ciso)"),
(
[("ConnectorA", "0.1.1")],
[],
"PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1",
),
(
(("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")),
(),
"PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0",
),
(
[("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")],
[],
"PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0",
),
(
[("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")],
[("ClientA", "1.0.1")],
"ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0",
),
],
)
@patch(
"firebolt.utils.usage_tracker.get_sdk_properties",
MagicMock(return_value=("1", "2", "Win", "ciso")),
)
def test_user_agent(connectors, expected_string):
assert get_user_agent_header(connectors) == expected_string
def test_user_agent(drivers, clients, expected_string):
assert get_user_agent_header(drivers, clients) == expected_string


@mark.parametrize(
"connectors",
"drivers,clients",
[
([1]),
((("Con1", "v1.1"), ("Con2"))),
(("Connector1.1")),
([1], []),
((("Con1", "v1.1"), ("Con2")), []),
(("Connector1.1"), ()),
(
[],
[1],
),
([], (("Con1", "v1.1"), ("Con2"))),
((), ("Connector1.1")),
],
)
@patch(
"firebolt.utils.usage_tracker.get_sdk_properties",
MagicMock(return_value=("1", "2", "Win", "ciso")),
)
def test_incorrect_user_agent(connectors):
def test_incorrect_user_agent(drivers, clients):
with raises(ValidationError):
get_user_agent_header(connectors)
get_user_agent_header(drivers, clients)