Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from json import JSONDecodeError
from socket import IPPROTO_TCP, SO_KEEPALIVE, SOL_SOCKET, TCP_KEEPIDLE
from types import TracebackType
from typing import Callable, List, Optional, Type

from httpx import HTTPStatusError, RequestError, Timeout
from httpcore.backends.auto import AutoBackend
from httpcore.backends.base import AsyncNetworkStream
from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout

from firebolt.async_db.cursor import BaseCursor, Cursor
from firebolt.client import DEFAULT_API_URL, AsyncClient
Expand All @@ -17,6 +20,8 @@
from firebolt.common.util import fix_url_schema

DEFAULT_TIMEOUT_SECONDS: int = 5
KEEPALIVE_FLAG: int = 1
KEEPIDLE_RATE: int = 60 # seconds
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do I understand correctly, as long as this value is less than 350, the trick will work, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically yes, I was trying to find a typical expected value for this but it differs a lot. Wiki gives 45-60 range so I went with the latter.



async def _resolve_engine_url(
Expand Down Expand Up @@ -131,6 +136,36 @@ async def connect_inner(
return connect_inner


class OverriddenHttpBackend(AutoBackend):
    """
    Network backend that enables TCP keep-alive on newly opened sockets.

    Short-term workaround for the load-balancer idle-timeout issue:
    https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
    httpx creates the connection right before executing a request, so the
    backend has to be overridden in order to apply the SO_KEEPALIVE and
    TCP_KEEPIDLE options to the underlying socket.
    """

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
    ) -> AsyncNetworkStream:
        stream = await super().connect_tcp(
            host, port, timeout=timeout, local_address=local_address
        )
        sock = stream.get_extra_info("socket")
        # Turn keep-alive probing on for this connection.
        sock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, KEEPALIVE_FLAG)
        # Begin probing after KEEPIDLE_RATE seconds of idleness.
        # NOTE(review): TCP_KEEPIDLE is a Linux-specific option — confirm
        # the supported platforms before relying on it elsewhere.
        sock.setsockopt(IPPROTO_TCP, TCP_KEEPIDLE, KEEPIDLE_RATE)
        return stream


class BaseConnection:
client_class: type
cursor_class: type
Expand All @@ -151,11 +186,14 @@ def __init__(
password: str,
api_endpoint: str = DEFAULT_API_URL,
):
transport = AsyncHTTPTransport()
transport._pool._network_backend = OverriddenHttpBackend()
self._client = AsyncClient(
auth=(username, password),
base_url=engine_url,
api_endpoint=api_endpoint,
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
transport=transport,
)
self.api_endpoint = api_endpoint
self.engine_url = engine_url
Expand Down
8 changes: 8 additions & 0 deletions tests/integration/dbapi/async/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ async def test_select(
data, all_types_query_response, "Invalid data returned by fetchmany"
)

# AWS ALB TCP timeout set to 350, make sure we handle the keepalive correctly
await c.execute(
"SELECT sleepEachRow(1) from numbers(360)",
set_parameters={"advanced_mode": "1", "use_standard_sql": "0"},
)
data = await c.fetchall()
assert len(data) == 360, "Invalid data size returned by fetchall"


@mark.asyncio
async def test_drop_create(
Expand Down
8 changes: 8 additions & 0 deletions tests/integration/dbapi/sync/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def test_select(
data, all_types_query_response, "Invalid data returned by fetchmany"
)

# AWS ALB TCP timeout set to 350, make sure we handle the keepalive correctly
c.execute(
"SELECT sleepEachRow(1) from numbers(360)",
set_parameters={"advanced_mode": "1", "use_standard_sql": "0"},
)
data = c.fetchall()
assert len(data) == 360, "Invalid data size returned by fetchall"


def test_drop_create(
connection: Connection, create_drop_description: List[Column]
Expand Down