class OverriddenHttpBackend(AutoBackend):
    """
    Network backend that enables TCP keep-alive on every new connection.

    This class is a short-term solution for the TCP keep-alive issue:
    https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
    Since httpx creates a connection right before executing a request,
    the backend has to be overridden in order to set the socket KEEPALIVE
    and KEEPIDLE settings on the freshly opened socket.
    """

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
    ) -> AsyncNetworkStream:
        """
        Open a TCP stream via the parent backend and turn on keep-alive.

        Args:
            host: Host to connect to (forwarded unchanged to the parent).
            port: Port to connect to (forwarded unchanged to the parent).
            timeout: Optional connect timeout, forwarded to the parent.
            local_address: Optional local address, forwarded to the parent.

        Returns:
            The stream returned by the parent ``connect_tcp``. When the
            stream exposes its underlying socket, SO_KEEPALIVE is enabled
            and the keep-alive idle time is set to KEEPIDLE_RATE seconds.
        """
        # Local imports: the idle-time option is looked up by name at call
        # time because not every platform exports ``socket.TCP_KEEPIDLE``.
        # NOTE(review): a module-level ``from socket import TCP_KEEPIDLE``
        # still fails at import time on such platforms and should be relaxed
        # as well.
        import socket as _socket
        import sys

        stream = await super().connect_tcp(
            host, port, timeout=timeout, local_address=local_address
        )

        # Fetch the raw socket once; a stream that does not expose one
        # cannot be tuned, so it is returned untouched instead of raising
        # AttributeError on ``None``.
        sock = stream.get_extra_info("socket")
        if sock is None:
            return stream

        # Enable keepalive
        sock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, KEEPALIVE_FLAG)

        # macOS names the "idle time before the first probe" option
        # TCP_KEEPALIVE (0x10) rather than TCP_KEEPIDLE, and older Python
        # versions do not export it at all, hence the numeric fallback.
        idle_option = getattr(_socket, "TCP_KEEPIDLE", None)
        if idle_option is None:
            darwin_default = 0x10 if sys.platform == "darwin" else None
            idle_option = getattr(_socket, "TCP_KEEPALIVE", darwin_default)

        # Start keep-alive probes after KEEPIDLE_RATE seconds of inactivity
        if idle_option is not None:
            sock.setsockopt(IPPROTO_TCP, idle_option, KEEPIDLE_RATE)
        return stream
b/tests/integration/dbapi/sync/test_queries.py @@ -61,6 +61,14 @@ def test_select( data, all_types_query_response, "Invalid data returned by fetchmany" ) + # AWS ALB TCP timeout set to 350, make sure we handle the keepalive correctly + c.execute( + "SELECT sleepEachRow(1) from numbers(360)", + set_parameters={"advanced_mode": "1", "use_standard_sql": "0"}, + ) + data = c.fetchall() + assert len(data) == 360, "Invalid data size returned by fetchall" + def test_drop_create( connection: Connection, create_drop_description: List[Column]