Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel
from requests.adapters import HTTPAdapter, Retry
from sqlalchemy.engine import URL, Connection, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand Down Expand Up @@ -263,13 +264,28 @@ class ExecuteSqlError(Exception):
)


def _generate_temporary_credentials(integration_id):
def _create_retry_session() -> requests.Session:
    """Build a ``requests.Session`` that retries on selected 5xx responses.

    Both the ``http://`` and ``https://`` prefixes get their own
    ``HTTPAdapter`` carrying the same ``Retry`` policy: up to 3 retries, a
    backoff factor of 0.5, triggered by status 500, 502, 503 or 504.

    The policy is not limited to POST; it applies to every method listed in
    ``allowed_methods`` below. POST is listed explicitly because the
    credential helpers in this module issue POST requests through the
    returned session.

    Returns:
        A new session with the retrying adapters mounted.
    """
    retry_policy = Retry(
        total=3,
        backoff_factor=0.5,
        status_forcelist=[500, 502, 503, 504],
        allowed_methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "TRACE"],
    )

    retrying_session = requests.Session()
    for scheme in ("http://", "https://"):
        # One adapter per scheme, exactly as many as the schemes we mount.
        retrying_session.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    return retrying_session
Comment thread
tkislan marked this conversation as resolved.


def _generate_temporary_credentials(integration_id) -> tuple[str, str]:
url = get_absolute_userpod_api_url(f"integrations/credentials/{integration_id}")

# Add project credentials in detached mode
headers = get_project_auth_headers()

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand All @@ -291,7 +307,8 @@ def _get_federated_auth_credentials(
headers = get_project_auth_headers()
headers["UserPodAuthContextToken"] = user_pod_auth_context_token

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand Down
248 changes: 244 additions & 4 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,9 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
class TestFederatedAuth(unittest.TestCase):
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
@mock.patch("deepnote_toolkit.sql.sql_execution.requests.post")
@mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session")
def test_get_federated_auth_credentials_returns_validated_response(
self, mock_post, mock_get_url, mock_get_headers
self, mock_create_session, mock_get_url, mock_get_headers
):
"""Test that _get_federated_auth_credentials properly validates and returns response data."""
from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials
Expand All @@ -603,12 +603,14 @@ def test_get_federated_auth_credentials_returns_validated_response(
mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-integration-id"
mock_get_headers.return_value = {"Authorization": "Bearer project-token"}

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"integrationType": "trino",
"accessToken": "test-access-token-123",
}
mock_post.return_value = mock_response
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

# Call the function
result = _get_federated_auth_credentials(
Expand All @@ -621,7 +623,7 @@ def test_get_federated_auth_credentials_returns_validated_response(
)

# Verify headers include both project auth and user pod auth context token
mock_post.assert_called_once_with(
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/federated-auth-token/test-integration-id",
timeout=10,
headers={
Expand Down Expand Up @@ -1019,3 +1021,241 @@ def test_databricks_connector_dialect_alias_is_registered(self):

self.assertEqual(url.drivername, "databricks+connector")
self.assertIsNotNone(dialect_cls)


class TestCreateRetrySession(unittest.TestCase):
    """Exercise the real urllib3 retry loop behind ``_create_retry_session``.

    The tests mock at the connection level
    (``HTTPConnectionPool._make_request``) rather than replacing
    ``_create_retry_session``, so the ``Retry`` policy mounted on the session
    actually fires on 5xx responses. ``Retry.sleep`` is patched out so the
    back-off never really waits; its call count doubles as the number of
    retries that were scheduled.
    """

    CREDENTIALS_URL = "https://api.example.com/integrations/credentials/test-id"
    FEDERATED_AUTH_URL = (
        "https://api.example.com/integrations/federated-auth-token/test-id"
    )
    AUTH_HEADERS = {"Authorization": "Bearer token"}

    @staticmethod
    def _fake_response(status, body=b"", headers=None):
        """Return a non-preloaded urllib3 response with the given status/body."""
        from urllib3 import HTTPResponse as Urllib3Response

        return Urllib3Response(
            body=io.BytesIO(body),
            status=status,
            headers=headers or {},
            preload_content=False,
        )

    @classmethod
    def _json_ok(cls, payload):
        """Return a 200 response whose body is ``payload`` encoded as JSON."""
        return cls._fake_response(
            200,
            json.dumps(payload).encode(),
            {"Content-Type": "application/json"},
        )

    def test_create_retry_session_configuration(self):
        """Verify the retry session is wired with the expected parameters."""
        from deepnote_toolkit.sql.sql_execution import _create_retry_session

        session = _create_retry_session()

        for prefix in ("http://", "https://"):
            # subTest reports which scheme is misconfigured instead of
            # stopping at the first failing prefix.
            with self.subTest(prefix=prefix):
                adapter = session.get_adapter(f"{prefix}example.com")
                retry = adapter.max_retries

                self.assertEqual(retry.total, 3)
                self.assertEqual(retry.backoff_factor, 0.5)
                self.assertEqual(
                    set(retry.status_forcelist), {500, 502, 503, 504}
                )
                self.assertIn("POST", retry.allowed_methods)

    # -- _generate_temporary_credentials ------------------------------------

    @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
    @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
    def test_generate_credentials_retries_on_5xx_then_succeeds(
        self,
        mock_get_url,
        mock_get_headers,
        mock_make_request,
        mock_retry_sleep,
    ):
        """Two 5xx failures followed by a 200 - the retry loop must
        transparently retry and ultimately return valid credentials."""
        from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

        mock_get_url.return_value = self.CREDENTIALS_URL
        mock_get_headers.return_value = dict(self.AUTH_HEADERS)

        mock_make_request.side_effect = [
            self._fake_response(500, b"Internal Server Error"),
            self._fake_response(502, b"Bad Gateway"),
            self._json_ok({"username": "user", "password": "pass"}),
        ]

        result = _generate_temporary_credentials("test-id")

        self.assertEqual(result, ("user", "pass"))
        self.assertEqual(mock_make_request.call_count, 3)
        self.assertEqual(mock_retry_sleep.call_count, 2)

    @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
    @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
    def test_generate_credentials_exhausts_retries_on_persistent_5xx(
        self,
        mock_get_url,
        mock_get_headers,
        mock_make_request,
        mock_retry_sleep,
    ):
        """All 4 attempts (1 original + 3 retries) return 500 -
        must raise ``RetryError``."""
        import requests

        from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

        mock_get_url.return_value = self.CREDENTIALS_URL
        mock_get_headers.return_value = dict(self.AUTH_HEADERS)

        mock_make_request.side_effect = [
            self._fake_response(500, b"Server Error") for _ in range(4)
        ]

        with self.assertRaises(requests.exceptions.RetryError):
            _generate_temporary_credentials("test-id")

        self.assertEqual(mock_make_request.call_count, 4)
        self.assertEqual(mock_retry_sleep.call_count, 3)

    @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
    @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
    def test_generate_credentials_no_retry_on_4xx(
        self,
        mock_get_url,
        mock_get_headers,
        mock_make_request,
        mock_retry_sleep,
    ):
        """A 400 is not in the retry status list - must fail immediately
        without retrying."""
        import requests

        from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

        mock_get_url.return_value = self.CREDENTIALS_URL
        mock_get_headers.return_value = dict(self.AUTH_HEADERS)

        mock_make_request.side_effect = [
            self._fake_response(400, b"Bad Request"),
        ]

        with self.assertRaises(requests.exceptions.HTTPError):
            _generate_temporary_credentials("test-id")

        self.assertEqual(mock_make_request.call_count, 1)
        mock_retry_sleep.assert_not_called()

    # -- _get_federated_auth_credentials ------------------------------------

    @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
    @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
    def test_federated_auth_retries_on_5xx_then_succeeds(
        self,
        mock_get_url,
        mock_get_headers,
        mock_make_request,
        mock_retry_sleep,
    ):
        """A 503 followed by a 200 - retry loop must recover and return
        valid ``FederatedAuthResponseData``."""
        from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials

        mock_get_url.return_value = self.FEDERATED_AUTH_URL
        mock_get_headers.return_value = dict(self.AUTH_HEADERS)

        mock_make_request.side_effect = [
            self._fake_response(503, b"Service Unavailable"),
            self._json_ok({"integrationType": "trino", "accessToken": "test-token"}),
        ]

        result = _get_federated_auth_credentials("test-id", "auth-context-token")

        self.assertEqual(result.integrationType, "trino")
        self.assertEqual(result.accessToken, "test-token")
        self.assertEqual(mock_make_request.call_count, 2)
        self.assertEqual(mock_retry_sleep.call_count, 1)

    @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None)
    @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
    @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
    def test_federated_auth_exhausts_retries_on_persistent_5xx(
        self,
        mock_get_url,
        mock_get_headers,
        mock_make_request,
        mock_retry_sleep,
    ):
        """All 4 attempts return 504 - must raise ``RetryError``."""
        import requests

        from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials

        mock_get_url.return_value = self.FEDERATED_AUTH_URL
        mock_get_headers.return_value = dict(self.AUTH_HEADERS)

        mock_make_request.side_effect = [
            self._fake_response(504, b"Gateway Timeout") for _ in range(4)
        ]

        with self.assertRaises(requests.exceptions.RetryError):
            _get_federated_auth_credentials("test-id", "auth-context-token")

        self.assertEqual(mock_make_request.call_count, 4)
        # Mirror the credentials variant: three back-off sleeps precede the
        # final failure, proving every retry was actually scheduled.
        self.assertEqual(mock_retry_sleep.call_count, 3)
Comment thread
tkislan marked this conversation as resolved.
Loading