From 3846b1807e0a23271a13042ed039f597247c5449 Mon Sep 17 00:00:00 2001
From: Petro Tiurin
Date: Mon, 24 Jan 2022 17:25:44 +0000
Subject: [PATCH 1/5] fix: Setting TCP keepalive values

---
 src/firebolt/async_db/connection.py | 30 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py
index 70235ead0d2..44ca3d49474 100644
--- a/src/firebolt/async_db/connection.py
+++ b/src/firebolt/async_db/connection.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from json import JSONDecodeError
+from socket import IPPROTO_TCP, SO_KEEPALIVE, SOL_SOCKET, TCP_KEEPIDLE
 from types import TracebackType
 from typing import Callable, List, Optional, Type
 
-from httpx import HTTPStatusError, RequestError, Timeout
+from httpcore.backends.auto import AutoBackend
+from httpcore.backends.base import AsyncNetworkStream
+from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout
 
 from firebolt.async_db.cursor import BaseCursor, Cursor
 from firebolt.client import DEFAULT_API_URL, AsyncClient
@@ -131,6 +134,28 @@ async def connect_inner(
     return connect_inner
 
 
+class OverriddenHttpBackend(AutoBackend):
+    """
+    This class is a short-term solution for TCP keep-alive issue:
+    https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
+    Since httpx creates a connection right before executing a request
+    backend has to be overridden in order to set the socket KEEPALIVE
+    and KEEPIDLE settings.
+    """
+
+    async def connect_tcp(
+        self, host: str, port: int, timeout: float = None, local_address: str = None
+    ) -> AsyncNetworkStream:
+        stream = await super().connect_tcp(
+            host, port, timeout=timeout, local_address=local_address
+        )
+        # Enable keepalive
+        stream.get_extra_info("socket").setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
+        # Set keepalive to 60 seconds
+        stream.get_extra_info("socket").setsockopt(IPPROTO_TCP, TCP_KEEPIDLE, 60)
+        return stream
+
+
 class BaseConnection:
     client_class: type
     cursor_class: type
@@ -151,11 +176,14 @@ def __init__(
         password: str,
        api_endpoint: str = DEFAULT_API_URL,
     ):
+        transport = AsyncHTTPTransport()
+        transport._pool._network_backend = OverriddenHttpBackend()
         self._client = AsyncClient(
             auth=(username, password),
             base_url=engine_url,
             api_endpoint=api_endpoint,
             timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
+            transport=transport,
         )
         self.api_endpoint = api_endpoint
         self.engine_url = engine_url
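The two socket options this backend sets can be exercised directly with Python's standard library, independently of httpx. A minimal sketch for reference only: the host and port are placeholders, and TCP_KEEPIDLE is Linux-specific (macOS exposes the equivalent interval as TCP_KEEPALIVE). The constant names mirror the ones the next patch introduces.

```python
import socket

KEEPALIVE_FLAG = 1
KEEPIDLE_RATE = 60  # seconds

# Open a plain TCP connection and apply the same options the patched backend sets.
sock = socket.create_connection(("api.example.com", 443), timeout=5)  # placeholder endpoint
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, KEEPALIVE_FLAG)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, KEEPIDLE_RATE)

# Read the options back to confirm the kernel accepted them.
assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) != 0
assert sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE) == KEEPIDLE_RATE
sock.close()
```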
stream.get_extra_info("socket").setsockopt(IPPROTO_TCP, TCP_KEEPIDLE, 60) + stream.get_extra_info("socket").setsockopt( + IPPROTO_TCP, TCP_KEEPIDLE, KEEPIDLE_RATE + ) return stream From 4e7e308761afb12af390a9a6ccf1abfa1995eba5 Mon Sep 17 00:00:00 2001 From: Petro Tiurin Date: Tue, 25 Jan 2022 15:37:08 +0000 Subject: [PATCH 3/5] Adding integration tests --- .../dbapi/async/test_queries_async.py | 8 +++ tests/integration/dbapi/sync/test_queries.py | 58 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index 205298f45b8..a3556805bf6 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -64,6 +64,14 @@ async def test_select( data, all_types_query_response, "Invalid data returned by fetchmany" ) + # AWS ALB TCP timeout set to 350, make sure we handle the keepalive correctly + await c.execute( + "SELECT sleepEachRow(1) from numbers(360)", + set_parameters={"advanced_mode": "1", "use_standard_sql": "0"}, + ) + data = await c.fetchall() + assert len(data) == 360, "Invalid data size returned by fetchall" + @mark.asyncio async def test_drop_create( diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index f730405d25b..284b9b80269 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -61,6 +61,14 @@ def test_select( data, all_types_query_response, "Invalid data returned by fetchmany" ) + # AWS ALB TCP timeout set to 350, make sure we handle the keepalive correctly + c.execute( + "SELECT sleepEachRow(1) from numbers(360)", + set_parameters={"advanced_mode": "1", "use_standard_sql": "0"}, + ) + data = c.fetchall() + assert len(data) == 360, "Invalid data size returned by fetchall" + def test_drop_create( connection: Connection, create_drop_description: List[Column] @@ -269,3 +277,53 @@ def test_multi_statement_query(connection: Connection) -> None: ) assert c.nextset() is None + + +def test_aws_timeout_query(connection: Connection) -> None: + """Test that we don't timeout on queries over 350 seconds.""" + + def test_empty_query(c: Cursor, query: str) -> None: + assert c.execute(query) == -1, "Invalid row count returned" + assert c.rowcount == -1, "Invalid rowcount value" + assert c.description is None, "Invalid description" + with raises(DataError): + c.fetchone() + + with raises(DataError): + c.fetchmany() + + with raises(DataError): + c.fetchall() + + with connection.cursor() as c: + c.execute("DROP TABLE IF EXISTS test_tb") + c.execute( + "CREATE FACT TABLE test_tb(id int, sn string null, f float," + "d date, dt datetime, b bool, a array(int)) primary index id" + ) + + test_empty_query( + c, + "INSERT INTO test_tb VALUES (1, 'sn', 1.1, '2021-01-01'," + "'2021-01-01 01:01:01', true, [1, 2, 3])", + ) + + assert ( + c.execute("SELECT * FROM test_tb ORDER BY test_tb.id") == 1 + ), "Invalid data length in table after insert" + + assert_deep_eq( + c.fetchall(), + [ + [ + 1, + "sn", + 1.1, + date(2021, 1, 1), + datetime(2021, 1, 1, 1, 1, 1), + 1, + [1, 2, 3], + ], + ], + "Invalid data in table after insert", + ) From ce10e4de2e51d7a50a069899d8a97372e35ac27e Mon Sep 17 00:00:00 2001 From: Petro Tiurin Date: Tue, 25 Jan 2022 15:41:31 +0000 Subject: [PATCH 4/5] Fix typehints --- src/firebolt/async_db/connection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git 
From ce10e4de2e51d7a50a069899d8a97372e35ac27e Mon Sep 17 00:00:00 2001
From: Petro Tiurin
Date: Tue, 25 Jan 2022 15:41:31 +0000
Subject: [PATCH 4/5] Fix typehints

---
 src/firebolt/async_db/connection.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py
index 76cda193d76..6070ede3c0a 100644
--- a/src/firebolt/async_db/connection.py
+++ b/src/firebolt/async_db/connection.py
@@ -146,7 +146,11 @@ class OverriddenHttpBackend(AutoBackend):
     """
 
     async def connect_tcp(
-        self, host: str, port: int, timeout: float = None, local_address: str = None
+        self,
+        host: str,
+        port: int,
+        timeout: Optional[float] = None,
+        local_address: Optional[str] = None,
     ) -> AsyncNetworkStream:
         stream = await super().connect_tcp(
             host, port, timeout=timeout, local_address=local_address

From 7d8219822d7a40f91857b1efedc072b1741f2276 Mon Sep 17 00:00:00 2001
From: Petro Tiurin
Date: Tue, 25 Jan 2022 15:55:33 +0000
Subject: [PATCH 5/5] Remove duplication

---
 tests/integration/dbapi/sync/test_queries.py | 50 --------------------
 1 file changed, 50 deletions(-)

diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py
index 284b9b80269..8016fccbd69 100644
--- a/tests/integration/dbapi/sync/test_queries.py
+++ b/tests/integration/dbapi/sync/test_queries.py
@@ -277,53 +277,3 @@ def test_multi_statement_query(connection: Connection) -> None:
     )
 
     assert c.nextset() is None
-
-
-def test_aws_timeout_query(connection: Connection) -> None:
-    """Test that we don't timeout on queries over 350 seconds."""
-
-    def test_empty_query(c: Cursor, query: str) -> None:
-        assert c.execute(query) == -1, "Invalid row count returned"
-        assert c.rowcount == -1, "Invalid rowcount value"
-        assert c.description is None, "Invalid description"
-        with raises(DataError):
-            c.fetchone()
-
-        with raises(DataError):
-            c.fetchmany()
-
-        with raises(DataError):
-            c.fetchall()
-
-    with connection.cursor() as c:
-        c.execute("DROP TABLE IF EXISTS test_tb")
-        c.execute(
-            "CREATE FACT TABLE test_tb(id int, sn string null, f float,"
-            "d date, dt datetime, b bool, a array(int)) primary index id"
-        )
-
-        test_empty_query(
-            c,
-            "INSERT INTO test_tb VALUES (1, 'sn', 1.1, '2021-01-01',"
-            "'2021-01-01 01:01:01', true, [1, 2, 3])",
-        )
-
-        assert (
-            c.execute("SELECT * FROM test_tb ORDER BY test_tb.id") == 1
-        ), "Invalid data length in table after insert"
-
-        assert_deep_eq(
-            c.fetchall(),
-            [
-                [
-                    1,
-                    "sn",
-                    1.1,
-                    date(2021, 1, 1),
-                    datetime(2021, 1, 1, 1, 1, 1),
-                    1,
-                    [1, 2, 3],
-                ],
-            ],
-            "Invalid data in table after insert",
-        )
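Outside of the connection class, the same override can be attached to any httpx client the way patch 1 does it. A minimal sketch, assuming the firebolt package from this series is importable and that the installed httpx/httpcore versions are the ones the patch targets (those providing httpcore.backends.auto and the private _pool._network_backend attribute); the URL is a placeholder.

```python
import asyncio

import httpx

from firebolt.async_db.connection import OverriddenHttpBackend


async def main() -> None:
    transport = httpx.AsyncHTTPTransport()
    # Swap the network backend on the transport's private connection pool,
    # mirroring what BaseConnection.__init__ does in patch 1. This reaches
    # into httpcore internals and can break across versions.
    transport._pool._network_backend = OverriddenHttpBackend()
    async with httpx.AsyncClient(transport=transport) as client:
        response = await client.get("https://api.example.com/health")  # placeholder URL
        print(response.status_code)


asyncio.run(main())
```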