diff --git a/poetry.lock b/poetry.lock index 1a8074c2a..193efa109 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "astroid" @@ -1348,6 +1348,38 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pybreaker" +version = "1.2.0" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pybreaker-1.2.0-py3-none-any.whl", hash = "sha256:c3e7683e29ecb3d4421265aaea55504f1186a2fdc1f17b6b091d80d1e1eb5ede"}, + {file = "pybreaker-1.2.0.tar.gz", hash = "sha256:18707776316f93a30c1be0e4fec1f8aa5ed19d7e395a218eb2f050c8524fb2dc"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + +[[package]] +name = "pybreaker" +version = "1.4.1" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0"}, + {file = "pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + [[package]] name = "pycparser" version = "2.22" @@ -1858,4 +1890,4 @@ pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0a3f611ef8747376f018c1df0a1ea7873368851873cc4bd3a4d51bba0bba847c" +content-hash = "56b62e3543644c91cc316b11d89025423a66daba5f36609c45bcb3eeb3ce3f54" diff --git a/pyproject.toml b/pyproject.toml index d26a71667..61c248e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] pyjwt = "^2.0.0" +pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 3e0be0d2b..a764b036d 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,6 +51,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = None, ): self.hostname = hostname self.access_token = access_token @@ -83,6 +84,7 @@ def __init__( self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent + self.telemetry_circuit_breaker_enabled = bool(telemetry_circuit_breaker_enabled) def get_effective_azure_login_app_id(hostname) -> str: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 96fb9cbb9..d5f7d3c8d 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -28,6 +28,42 @@ logger = logging.getLogger(__name__) +def _extract_http_status_from_max_retry_error(e: MaxRetryError) -> Optional[int]: + """ + Extract HTTP status code from MaxRetryError if available. + + urllib3 structures MaxRetryError in different ways depending on the failure scenario: + - e.reason.response.status: Most common case when retries are exhausted + - e.response.status: Alternate structure in some scenarios + + Args: + e: MaxRetryError exception from urllib3 + + Returns: + HTTP status code as int if found, None otherwise + """ + # Try primary structure: e.reason.response.status + if ( + hasattr(e, "reason") + and e.reason is not None + and hasattr(e.reason, "response") + and e.reason.response is not None + ): + http_code = getattr(e.reason.response, "status", None) + if http_code is not None: + return http_code + + # Try alternate structure: e.response.status + if ( + hasattr(e, "response") + and e.response is not None + and hasattr(e.response, "status") + ): + return e.response.status + + return None + + class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. @@ -264,7 +300,16 @@ def request_context( yield response except MaxRetryError as e: logger.error("HTTP request failed after retries: %s", e) - raise RequestError(f"HTTP request failed: {e}") + + # Extract HTTP status code from MaxRetryError if available + http_code = _extract_http_status_from_max_retry_error(e) + + context = {} + if http_code is not None: + context["http-code"] = http_code + logger.error("HTTP request failed with status code: %d", http_code) + + raise RequestError(f"HTTP request failed: {e}", context=context) except Exception as e: logger.error("HTTP request error: %s", e) raise RequestError(f"HTTP request error: {e}") diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3a3a6b3c5..24844d573 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -143,3 +143,24 @@ class SessionAlreadyClosedError(RequestError): class CursorAlreadyClosedError(RequestError): """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected.""" + + +class TelemetryRateLimitError(Exception): + """Raised when telemetry endpoint returns 429 or 503, indicating rate limiting or service unavailable. + This exception is used exclusively by the circuit breaker to track telemetry rate limiting events.""" + + +class TelemetryNonRateLimitError(Exception): + """Wrapper for telemetry errors that should NOT trigger circuit breaker. + + This exception wraps non-rate-limiting errors (network errors, timeouts, server errors, etc.) + and is excluded from circuit breaker failure counting. Only TelemetryRateLimitError should + open the circuit breaker. + + Attributes: + original_exception: The actual exception that occurred + """ + + def __init__(self, original_exception: Exception): + self.original_exception = original_exception + super().__init__(f"Non-rate-limit telemetry error: {original_exception}") diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py new file mode 100644 index 000000000..852f0d916 --- /dev/null +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -0,0 +1,112 @@ +""" +Circuit breaker implementation for telemetry requests. + +This module provides circuit breaker functionality to prevent telemetry failures +from impacting the main SQL operations. It uses pybreaker library to implement +the circuit breaker pattern. +""" + +import logging +import threading +from typing import Dict + +import pybreaker +from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener + +from databricks.sql.exc import TelemetryNonRateLimitError + +logger = logging.getLogger(__name__) + +# Circuit Breaker Constants +MINIMUM_CALLS = 20 # Number of failures before circuit opens +RESET_TIMEOUT = 30 # Seconds to wait before trying to close circuit +NAME_PREFIX = "telemetry-circuit-breaker" + +# Circuit Breaker State Constants (used in logging) +CIRCUIT_BREAKER_STATE_OPEN = "open" +CIRCUIT_BREAKER_STATE_CLOSED = "closed" +CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open" + +# Logging Message Constants +LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s" +LOG_CIRCUIT_BREAKER_OPENED = ( + "Circuit breaker opened for %s - telemetry requests will be blocked" +) +LOG_CIRCUIT_BREAKER_CLOSED = ( + "Circuit breaker closed for %s - telemetry requests will be allowed" +) +LOG_CIRCUIT_BREAKER_HALF_OPEN = ( + "Circuit breaker half-open for %s - testing telemetry requests" +) + + +class CircuitBreakerStateListener(CircuitBreakerListener): + """Listener for circuit breaker state changes.""" + + def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None: + """Called before the circuit breaker calls a function.""" + pass + + def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: + """Called when a function called by the circuit breaker fails.""" + pass + + def success(self, cb: CircuitBreaker) -> None: + """Called when a function called by the circuit breaker succeeds.""" + pass + + def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: + """Called when the circuit breaker state changes.""" + old_state_name = old_state.name if old_state else "None" + new_state_name = new_state.name if new_state else "None" + + logger.info( + LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name + ) + + if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: + logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: + logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: + logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + + +class CircuitBreakerManager: + """ + Manages circuit breaker instances for telemetry requests. + + Creates and caches circuit breaker instances per host to ensure telemetry + failures don't impact main SQL operations. + """ + + _instances: Dict[str, CircuitBreaker] = {} + _lock = threading.RLock() + + @classmethod + def get_circuit_breaker(cls, host: str) -> CircuitBreaker: + """ + Get or create a circuit breaker instance for the specified host. + + Args: + host: The hostname for which to get the circuit breaker + + Returns: + CircuitBreaker instance for the host + """ + with cls._lock: + if host not in cls._instances: + breaker = CircuitBreaker( + fail_max=MINIMUM_CALLS, + reset_timeout=RESET_TIMEOUT, + name=f"{NAME_PREFIX}-{host}", + exclude=[ + TelemetryNonRateLimitError + ], # Don't count these as failures + ) + # Add state change listener for logging + breaker.add_listener(CircuitBreakerStateListener()) + cls._instances[host] = breaker + logger.debug("Created circuit breaker for host: %s", host) + + return cls._instances[host] diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 134757fe5..177d5445c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -41,6 +41,11 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) if TYPE_CHECKING: from databricks.sql.client import Connection @@ -166,21 +171,21 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, - telemetry_enabled, - session_id_hex, + telemetry_enabled: bool, + session_id_hex: str, auth_provider, - host_url, + host_url: str, executor, - batch_size, + batch_size: int, client_context, - ): + ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = batch_size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch = [] + self._events_batch: list = [] self._lock = threading.RLock() self._driver_connection_params = None self._host_url = host_url @@ -189,6 +194,19 @@ def __init__( # Create own HTTP client from client context self._http_client = UnifiedHttpClient(client_context) + # Create telemetry push client based on circuit breaker enabled flag + if client_context.telemetry_circuit_breaker_enabled: + # Create circuit breaker telemetry push client (circuit breakers created on-demand) + self._telemetry_push_client: ITelemetryPushClient = ( + CircuitBreakerTelemetryPushClient( + TelemetryPushClient(self._http_client), + host_url, + ) + ) + else: + # Circuit breaker disabled - use direct telemetry push client + self._telemetry_push_client = TelemetryPushClient(self._http_client) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) @@ -254,7 +272,7 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request( + response = self._telemetry_push_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) return response diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py new file mode 100644 index 000000000..461a57738 --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -0,0 +1,201 @@ +""" +Telemetry push client interface and implementations. + +This module provides an interface for telemetry push clients with two implementations: +1. TelemetryPushClient - Direct HTTP client implementation +2. CircuitBreakerTelemetryPushClient - Circuit breaker wrapper implementation +""" + +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional + +try: + from urllib3 import BaseHTTPResponse +except ImportError: + from urllib3 import HTTPResponse as BaseHTTPResponse +from pybreaker import CircuitBreakerError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import ( + TelemetryRateLimitError, + TelemetryNonRateLimitError, + RequestError, +) +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + +logger = logging.getLogger(__name__) + + +class ITelemetryPushClient(ABC): + """Interface for telemetry push clients.""" + + @abstractmethod + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request.""" + pass + + +class TelemetryPushClient(ITelemetryPushClient): + """Direct HTTP client implementation for telemetry requests.""" + + def __init__(self, http_client: UnifiedHttpClient): + """ + Initialize the telemetry push client. + + Args: + http_client: The underlying HTTP client + """ + self._http_client = http_client + logger.debug("TelemetryPushClient initialized") + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request using the underlying HTTP client.""" + return self._http_client.request(method, url, headers, **kwargs) + + +class CircuitBreakerTelemetryPushClient(ITelemetryPushClient): + """Circuit breaker wrapper implementation for telemetry requests.""" + + def __init__(self, delegate: ITelemetryPushClient, host: str): + """ + Initialize the circuit breaker telemetry push client. + + Args: + delegate: The underlying telemetry push client to wrap + host: The hostname for circuit breaker identification + """ + self._delegate = delegate + self._host = host + + # Get circuit breaker for this host (creates if doesn't exist) + self._circuit_breaker = CircuitBreakerManager.get_circuit_breaker(host) + + logger.debug( + "CircuitBreakerTelemetryPushClient initialized for host %s", + host, + ) + + def _make_request_and_check_status( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]], + **kwargs, + ) -> BaseHTTPResponse: + """ + Make the request and check response status. + + Raises TelemetryRateLimitError for 429/503 (circuit breaker counts these). + Wraps other errors in TelemetryNonRateLimitError (circuit breaker excludes these). + + Args: + method: HTTP method + url: Request URL + headers: Request headers + **kwargs: Additional request parameters + + Returns: + HTTP response + + Raises: + TelemetryRateLimitError: For 429/503 status codes (circuit breaker counts) + TelemetryNonRateLimitError: For other errors (circuit breaker excludes) + """ + try: + response = self._delegate.request(method, url, headers, **kwargs) + + # Check for rate limiting or service unavailable + if response.status in [429, 503]: + logger.warning( + "Telemetry endpoint returned %d for host %s, triggering circuit breaker", + response.status, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry endpoint rate limited or unavailable: {response.status}" + ) + + return response + + except Exception as e: + # Don't catch TelemetryRateLimitError - let it propagate to circuit breaker + if isinstance(e, TelemetryRateLimitError): + raise + + # Check if it's a RequestError with rate limiting status code (exhausted retries) + if isinstance(e, RequestError): + http_code = ( + e.context.get("http-code") + if hasattr(e, "context") and e.context + else None + ) + + if http_code in [429, 503]: + logger.debug( + "Telemetry retries exhausted with status %d for host %s, triggering circuit breaker", + http_code, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry rate limited after retries: {http_code}" + ) + + # NOT rate limiting (500 errors, network errors, timeouts, etc.) + # Wrap in TelemetryNonRateLimitError so circuit breaker excludes it + logger.debug( + "Non-rate-limit telemetry error for host %s: %s, wrapping to exclude from circuit breaker", + self._host, + e, + ) + raise TelemetryNonRateLimitError(e) from e + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """ + Make an HTTP request with circuit breaker protection. + + Circuit breaker only opens for TelemetryRateLimitError (429/503 responses). + Other errors are wrapped in TelemetryNonRateLimitError and excluded from circuit breaker. + All exceptions propagate to caller (TelemetryClient callback handles them). + """ + try: + # Use circuit breaker to protect the request + # TelemetryRateLimitError will trigger circuit breaker + # TelemetryNonRateLimitError is excluded from circuit breaker + return self._circuit_breaker.call( + self._make_request_and_check_status, + method, + url, + headers, + **kwargs, + ) + + except TelemetryNonRateLimitError as e: + # Unwrap and re-raise original exception + # Circuit breaker didn't count this, but caller should handle it + logger.debug( + "Non-rate-limit telemetry error for host %s, re-raising original: %s", + self._host, + e.original_exception, + ) + raise e.original_exception from e + # All other exceptions (TelemetryRateLimitError, CircuitBreakerError) propagate as-is diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9f96e8743..b46784b10 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -922,4 +922,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs): proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), + telemetry_circuit_breaker_enabled=kwargs.get( + "_telemetry_circuit_breaker_enabled" + ), ) diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py new file mode 100644 index 000000000..45c494d19 --- /dev/null +++ b/tests/e2e/test_circuit_breaker.py @@ -0,0 +1,232 @@ +""" +E2E tests for circuit breaker functionality in telemetry. + +This test suite verifies: +1. Circuit breaker opens after rate limit failures (429/503) +2. Circuit breaker blocks subsequent calls while open +3. Circuit breaker does not trigger for non-rate-limit errors +4. Circuit breaker can be disabled via configuration flag +5. Circuit breaker closes after reset timeout + +Run with: + pytest tests/e2e/test_circuit_breaker.py -v -s +""" + +import time +from unittest.mock import patch, MagicMock + +import pytest +from pybreaker import STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN +from urllib3 import HTTPResponse + +import databricks.sql as sql +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +@pytest.fixture(autouse=True) +def aggressive_circuit_breaker_config(): + """ + Configure circuit breaker to be aggressive for faster testing. + Opens after 2 failures instead of 20, with 5 second timeout. + """ + from databricks.sql.telemetry import circuit_breaker_manager + + original_minimum_calls = circuit_breaker_manager.MINIMUM_CALLS + original_reset_timeout = circuit_breaker_manager.RESET_TIMEOUT + + circuit_breaker_manager.MINIMUM_CALLS = 2 + circuit_breaker_manager.RESET_TIMEOUT = 5 + + CircuitBreakerManager._instances.clear() + + yield + + circuit_breaker_manager.MINIMUM_CALLS = original_minimum_calls + circuit_breaker_manager.RESET_TIMEOUT = original_reset_timeout + CircuitBreakerManager._instances.clear() + + +class TestCircuitBreakerTelemetry: + """Tests for circuit breaker functionality with telemetry""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + """Get connection details from pytest fixture""" + self.arguments = connection_details.copy() + + def create_mock_response(self, status_code): + """Helper to create mock HTTP response.""" + response = MagicMock(spec=HTTPResponse) + response.status = status_code + response.data = { + 429: b"Too Many Requests", + 503: b"Service Unavailable", + 500: b"Internal Server Error", + }.get(status_code, b"Response") + return response + + @pytest.mark.parametrize("status_code,should_trigger", [ + (429, True), + (503, True), + (500, False), + ]) + def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): + """ + Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). + """ + request_count = {"count": 0} + + def mock_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(status_code) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + assert circuit_breaker.current_state == STATE_CLOSED + + cursor = conn.cursor() + + # Execute queries to trigger telemetry + for i in range(1, 6): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.5) + + if should_trigger: + # Circuit should be OPEN after 2 rate-limit failures + assert circuit_breaker.current_state == STATE_OPEN + assert circuit_breaker.fail_counter == 2 + + # Track requests before another query + requests_before = request_count["count"] + cursor.execute("SELECT 99") + cursor.fetchone() + time.sleep(1) + + # No new telemetry requests (circuit is open) + assert request_count["count"] == requests_before + else: + # Circuit should remain CLOSED for non-rate-limit errors + assert circuit_breaker.current_state == STATE_CLOSED + assert circuit_breaker.fail_counter == 0 + assert request_count["count"] >= 5 + + def test_circuit_breaker_disabled_allows_all_calls(self): + """ + Verify that when circuit breaker is disabled, all calls go through + even with rate limit errors. + """ + request_count = {"count": 0} + + def mock_rate_limited_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(429) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_rate_limited_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=False, # Disabled + ) as conn: + cursor = conn.cursor() + + for i in range(5): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.3) + + assert request_count["count"] >= 5 + + def test_circuit_breaker_recovers_after_reset_timeout(self): + """ + Verify circuit breaker transitions to HALF_OPEN after reset timeout + and eventually CLOSES if requests succeed. + """ + request_count = {"count": 0} + fail_requests = {"enabled": True} + + def mock_conditional_request(*args, **kwargs): + request_count["count"] += 1 + status = 429 if fail_requests["enabled"] else 200 + return self.create_mock_response(status) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_conditional_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + cursor = conn.cursor() + + # Trigger failures to open circuit + cursor.execute("SELECT 1") + cursor.fetchone() + time.sleep(1) + + cursor.execute("SELECT 2") + cursor.fetchone() + time.sleep(2) + + assert circuit_breaker.current_state == STATE_OPEN + + # Wait for reset timeout (5 seconds in test) + time.sleep(6) + + # Now make requests succeed + fail_requests["enabled"] = False + + # Execute query to trigger HALF_OPEN state + cursor.execute("SELECT 3") + cursor.fetchone() + time.sleep(1) + + # Circuit should be recovering + assert circuit_breaker.current_state in [ + STATE_HALF_OPEN, + STATE_CLOSED, + ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + + # Execute more queries to fully recover + cursor.execute("SELECT 4") + cursor.fetchone() + time.sleep(1) + + current_state = circuit_breaker.current_state + assert current_state in [ + STATE_CLOSED, + STATE_HALF_OPEN, + ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/test_circuit_breaker_http_client.py b/tests/unit/test_circuit_breaker_http_client.py new file mode 100644 index 000000000..432ca1be3 --- /dev/null +++ b/tests/unit/test_circuit_breaker_http_client.py @@ -0,0 +1,208 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + def test_direct_client_has_no_circuit_breaker(self): + """Test that direct client does not have circuit breaker functionality.""" + # Direct client should work without circuit breaker + assert isinstance(self.client, TelemetryPushClient) + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_enabled_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_circuit_breaker_error(self): + """Test request when circuit breaker is open - should raise CircuitBreakerError.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Circuit breaker open should raise (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_enabled_other_error(self): + """Test request when other error occurs - should raise original exception.""" + # Mock delegate to raise a different error (not rate limiting) + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Non-rate-limit errors are unwrapped and raised + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_is_circuit_breaker_enabled(self): + """Test checking if circuit breaker is enabled.""" + assert self.client._circuit_breaker is not None + + def test_circuit_breaker_state_logging(self): + """Test that circuit breaker errors are raised (no longer silent).""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Should raise CircuitBreakerError (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_other_error_logging(self): + """Test that other errors are wrapped, logged, then unwrapped and raised.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Should raise the original ValueError + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that debug was logged (for wrapping and/or unwrapping) + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + from databricks.sql.exc import TelemetryRateLimitError + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures (429) + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + # All calls should raise TelemetryRateLimitError + # After MINIMUM_CALLS failures, circuit breaker opens + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + # Should have some rate limit errors before circuit opens, then circuit breaker errors + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + import time + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures first (429) + from databricks.sql.exc import TelemetryRateLimitError + from pybreaker import CircuitBreakerError + + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + # Trigger enough rate limit failures to open circuit + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass # Expected - circuit breaker opens after MINIMUM_CALLS failures + + # Circuit should be open now - raises CircuitBreakerError + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate successful calls (200 response) + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + # Should work again with actual success response + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py new file mode 100644 index 000000000..e8ed4e809 --- /dev/null +++ b/tests/unit/test_circuit_breaker_manager.py @@ -0,0 +1,160 @@ +""" +Unit tests for circuit breaker manager functionality. +""" + +import pytest +import threading +import time +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + NAME_PREFIX as CIRCUIT_BREAKER_NAME, +) +from pybreaker import CircuitBreakerError + + +class TestCircuitBreakerManager: + """Test cases for CircuitBreakerManager.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_get_circuit_breaker_creates_instance(self): + """Test getting circuit breaker creates instance with correct config.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.name == "telemetry-circuit-breaker-test-host" + assert breaker.fail_max == MINIMUM_CALLS + + def test_get_circuit_breaker_same_host_returns_same_instance(self): + """Test that same host returns same circuit breaker instance.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("test-host") + breaker2 = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker1 is breaker2 + + def test_get_circuit_breaker_different_hosts_return_different_instances(self): + """Test that different hosts return different circuit breaker instances.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("host1") + breaker2 = CircuitBreakerManager.get_circuit_breaker("host2") + + assert breaker1 is not breaker2 + assert breaker1.name != breaker2.name + + def test_thread_safety(self): + """Test thread safety of circuit breaker manager.""" + results = [] + + def get_breaker(host): + breaker = CircuitBreakerManager.get_circuit_breaker(host) + results.append(breaker) + + threads = [] + for i in range(10): + thread = threading.Thread(target=get_breaker, args=(f"host{i % 3}",)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(results) == 10 + + # All breakers for same host should be same instance + host0_breakers = [b for b in results if b.name.endswith("host0")] + assert all(b is host0_breakers[0] for b in host0_breakers) + + +class TestCircuitBreakerIntegration: + """Integration tests for circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_state_transitions(self): + """Test circuit breaker state transitions from closed to open.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.current_state == "closed" + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold (MINIMUM_CALLS = 20) + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + def test_circuit_breaker_recovery(self): + """Test circuit breaker recovery after failures.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Try successful call to close circuit breaker + def successful_func(): + return "success" + + try: + result = breaker.call(successful_func) + assert result == "success" + except CircuitBreakerError: + pass # Circuit might still be open, acceptable + + assert breaker.current_state in ["closed", "half-open", "open"] + + @pytest.mark.parametrize("old_state,new_state", [ + ("closed", "open"), + ("open", "half-open"), + ("half-open", "closed"), + ("closed", "half-open"), + ]) + def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): + """Test circuit breaker state listener logs all state transitions.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerStateListener, + ) + + listener = CircuitBreakerStateListener() + mock_cb = Mock() + mock_cb.name = "test-breaker" + + mock_old_state = Mock() + mock_old_state.name = old_state + + mock_new_state = Mock() + mock_new_state.name = new_state + + with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: + listener.state_change(mock_cb, mock_old_state, mock_new_state) + mock_logger.info.assert_called() diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 36141ee2b..6f5a01c7b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -37,7 +37,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -95,7 +97,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -231,7 +233,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -299,7 +303,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -382,8 +388,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -410,8 +418,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -438,8 +448,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py new file mode 100644 index 000000000..0e9455e1f --- /dev/null +++ b/tests/unit/test_telemetry_push_client.py @@ -0,0 +1,213 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import TelemetryRateLimitError +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_circuit_breaker_open(self): + """Test request when circuit breaker is open raises CircuitBreakerError.""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_other_error(self): + """Test request when other error occurs raises original exception.""" + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code,expected_error", [ + (429, TelemetryRateLimitError), + (503, TelemetryRateLimitError), + ]) + def test_request_rate_limit_codes(self, status_code, expected_error): + """Test that rate-limit status codes raise TelemetryRateLimitError.""" + mock_response = Mock() + mock_response.status = status_code + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(expected_error): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_non_rate_limit_code(self): + """Test that non-rate-limit status codes return response.""" + mock_response = Mock() + mock_response.status = 500 + mock_response.data = b'Server error' + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 500 + + def test_rate_limit_error_logging(self): + """Test that rate limit errors are logged with circuit breaker context.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(TelemetryRateLimitError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "429" in str(warning_args) + assert "circuit breaker" in warning_args[0] + + def test_other_error_logging(self): + """Test that other errors are logged during wrapping/unwrapping.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + import time + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Trigger failures + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass + + # Circuit should be open + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate success + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 + + def test_urllib3_import_fallback(self): + """Test that the urllib3 import fallback works correctly.""" + from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse + assert BaseHTTPResponse is not None diff --git a/tests/unit/test_telemetry_request_error_handling.py b/tests/unit/test_telemetry_request_error_handling.py new file mode 100644 index 000000000..aa31f6628 --- /dev/null +++ b/tests/unit/test_telemetry_request_error_handling.py @@ -0,0 +1,96 @@ +""" +Unit tests specifically for telemetry_push_client RequestError handling +with http-code context extraction for rate limiting detection. +""" + +import pytest +from unittest.mock import Mock + +from databricks.sql.telemetry.telemetry_push_client import ( + CircuitBreakerTelemetryPushClient, + TelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError, TelemetryRateLimitError +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +class TestTelemetryPushClientRequestErrorHandling: + """Test RequestError handling and http-code context extraction.""" + + @pytest.fixture + def setup_circuit_breaker(self): + """Setup circuit breaker for testing.""" + CircuitBreakerManager._instances.clear() + yield + CircuitBreakerManager._instances.clear() + + @pytest.fixture + def mock_delegate(self): + """Create mock delegate client.""" + return Mock(spec=TelemetryPushClient) + + @pytest.fixture + def client(self, mock_delegate, setup_circuit_breaker): + """Create CircuitBreakerTelemetryPushClient instance.""" + return CircuitBreakerTelemetryPushClient(mock_delegate, "test-host.example.com") + + @pytest.mark.parametrize("status_code", [429, 503]) + def test_request_error_with_rate_limit_codes(self, client, mock_delegate, status_code): + """Test that RequestError with rate-limit codes raises TelemetryRateLimitError.""" + request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + mock_delegate.request.side_effect = request_error + + with pytest.raises(TelemetryRateLimitError): + client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code", [500, 400, 404]) + def test_request_error_with_non_rate_limit_codes(self, client, mock_delegate, status_code): + """Test that RequestError with non-rate-limit codes raises original RequestError.""" + request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("context", [{}, None, "429"]) + def test_request_error_with_invalid_context(self, client, mock_delegate, context): + """Test RequestError with invalid/missing context raises original error.""" + request_error = RequestError("HTTP request failed") + if context == "429": + # Edge case: http-code as string instead of int + request_error.context = {"http-code": context} + else: + request_error.context = context + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_error_missing_context_attribute(self, client, mock_delegate): + """Test RequestError without context attribute raises original error.""" + request_error = RequestError("HTTP request failed") + if hasattr(request_error, "context"): + delattr(request_error, "context") + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_http_code_extraction_prioritization(self, client, mock_delegate): + """Test that http-code from RequestError context is correctly extracted.""" + request_error = RequestError( + "HTTP request failed after retries", context={"http-code": 503} + ) + mock_delegate.request.side_effect = request_error + + with pytest.raises(TelemetryRateLimitError): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_non_request_error_exceptions_raised(self, client, mock_delegate): + """Test that non-RequestError exceptions are wrapped then unwrapped.""" + generic_error = ValueError("Network timeout") + mock_delegate.request.side_effect = generic_error + + with pytest.raises(ValueError, match="Network timeout"): + client.request(HttpMethod.POST, "https://test.com", {}) diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py new file mode 100644 index 000000000..4e9ce1bbf --- /dev/null +++ b/tests/unit/test_unified_http_client.py @@ -0,0 +1,136 @@ +""" +Unit tests for UnifiedHttpClient, specifically testing MaxRetryError handling +and HTTP status code extraction. +""" + +import pytest +from unittest.mock import Mock, patch +from urllib3.exceptions import MaxRetryError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError +from databricks.sql.auth.common import ClientContext +from databricks.sql.types import SSLOptions + + +class TestUnifiedHttpClientMaxRetryError: + """Test MaxRetryError handling and HTTP status code extraction.""" + + @pytest.fixture + def client_context(self): + """Create a minimal ClientContext for testing.""" + context = Mock(spec=ClientContext) + context.hostname = "https://test.databricks.com" + context.ssl_options = SSLOptions( + tls_verify=True, + tls_verify_hostname=True, + tls_trusted_ca_file=None, + tls_client_cert_file=None, + tls_client_cert_key_file=None, + tls_client_cert_key_password=None, + ) + context.socket_timeout = 30 + context.retry_stop_after_attempts_count = 3 + context.retry_delay_min = 1.0 + context.retry_delay_max = 10.0 + context.retry_stop_after_attempts_duration = 300.0 + context.retry_delay_default = 5.0 + context.retry_dangerous_codes = [] + context.proxy_auth_method = None + context.pool_connections = 10 + context.pool_maxsize = 20 + context.user_agent = "test-agent" + return context + + @pytest.fixture + def http_client(self, client_context): + """Create UnifiedHttpClient instance.""" + return UnifiedHttpClient(client_context) + + @pytest.mark.parametrize("status_code,path", [ + (429, "reason.response"), + (503, "reason.response"), + (500, "direct_response"), + ]) + def test_max_retry_error_with_status_codes(self, http_client, status_code, path): + """Test MaxRetryError with various status codes and response paths.""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + + if path == "reason.response": + max_retry_error.reason = Mock() + max_retry_error.reason.response = Mock() + max_retry_error.reason.response.status = status_code + else: # direct_response + max_retry_error.response = Mock() + max_retry_error.response.status = status_code + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request( + HttpMethod.POST, "http://test.com", headers={"test": "header"} + ) + + error = exc_info.value + assert hasattr(error, "context") + assert "http-code" in error.context + assert error.context["http-code"] == status_code + + @pytest.mark.parametrize("setup_func", [ + lambda e: None, # No setup - error with no attributes + lambda e: setattr(e, "reason", None), # reason=None + lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None + lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr + ]) + def test_max_retry_error_missing_status(self, http_client, setup_func): + """Test MaxRetryError without status code (no crash, empty context).""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + setup_func(max_retry_error) + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.GET, "http://test.com") + + error = exc_info.value + assert error.context == {} + + def test_max_retry_error_prefers_reason_response(self, http_client): + """Test that e.reason.response.status is preferred over e.response.status.""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + + # Set both structures with different status codes + max_retry_error.reason = Mock() + max_retry_error.reason.response = Mock() + max_retry_error.reason.response.status = 429 # Should use this + + max_retry_error.response = Mock() + max_retry_error.response.status = 500 # Should be ignored + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.GET, "http://test.com") + + error = exc_info.value + assert error.context["http-code"] == 429 + + def test_generic_exception_no_crash(self, http_client): + """Test that generic exceptions don't crash when checking for status code.""" + generic_error = Exception("Network error") + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=generic_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.POST, "http://test.com") + + error = exc_info.value + assert "HTTP request error" in str(error)