diff --git a/sdk/python/feast/infra/offline_stores/contrib/clickhouse_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/clickhouse_offline_store/tests/data_source.py
index 80fd1751dc5..4234c46eb3f 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/clickhouse_offline_store/tests/data_source.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/clickhouse_offline_store/tests/data_source.py
@@ -15,6 +15,8 @@
 from feast.infra.offline_stores.contrib.clickhouse_offline_store.clickhouse_source import (
     ClickhouseSource,
 )
+from feast.infra.utils.clickhouse.clickhouse_config import ClickhouseConfig
+from feast.infra.utils.clickhouse.connection_utils import get_client
 from tests.integration.feature_repos.universal.data_source_creator import (
     DataSourceCreator,
 )
@@ -114,3 +116,29 @@ def create_saved_dataset_destination(self):
 
     def teardown(self):
         pass
+
+
+def test_get_client_with_additional_params(clickhouse_container):
+    """
+    Test that get_client works with a real ClickHouse container and properly passes
+    additional settings like send_receive_timeout.
+    """
+    # Create config with custom send_receive_timeout
+    config = ClickhouseConfig(
+        host=clickhouse_container.get_container_host_ip(),
+        port=clickhouse_container.get_exposed_port(8123),
+        user=CLICKHOUSE_USER,
+        password=CLICKHOUSE_PASSWORD,
+        database=CLICKHOUSE_OFFLINE_DB,
+        additional_client_args={"send_receive_timeout": 60},
+    )
+
+    # Get client and verify it works
+    client = get_client(config)
+
+    # Verify client is connected and functional by running a simple query
+    result = client.query("SELECT 1 AS test_value")
+    assert result.result_rows == [(1,)]
+
+    # Verify the send_receive_timeout was applied
+    assert client.timeout._read == 60
diff --git a/sdk/python/feast/infra/utils/clickhouse/clickhouse_config.py b/sdk/python/feast/infra/utils/clickhouse/clickhouse_config.py
index 1f163e0a81b..75167f8a60e 100644
--- a/sdk/python/feast/infra/utils/clickhouse/clickhouse_config.py
+++ b/sdk/python/feast/infra/utils/clickhouse/clickhouse_config.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import ConfigDict, StrictStr
 
 from feast.repo_config import FeastConfigBaseModel
@@ -11,4 +13,8 @@ class ClickhouseConfig(FeastConfigBaseModel):
     password: StrictStr
     use_temporary_tables_for_entity_df: bool = True
 
+    # See https://github.com/ClickHouse/clickhouse-connect/blob/main/clickhouse_connect/driver/__init__.py#L51
+    # Some typical ones e.g. send_receive_timeout (read_timeout), etc
+    additional_client_args: dict[str, Any] | None = None
+
     model_config = ConfigDict(frozen=True)
diff --git a/sdk/python/feast/infra/utils/clickhouse/connection_utils.py b/sdk/python/feast/infra/utils/clickhouse/connection_utils.py
index 88f5334db14..6d5f1b87052 100644
--- a/sdk/python/feast/infra/utils/clickhouse/connection_utils.py
+++ b/sdk/python/feast/infra/utils/clickhouse/connection_utils.py
@@ -11,12 +11,23 @@
 def get_client(config: ClickhouseConfig) -> Client:
+    """Return the calling thread's ClickHouse client, creating it on first use.
+
+    The client is cached on ``thread_local``; later calls from the same thread
+    return the cached instance without reading ``config`` again.
+    ``config.additional_client_args``, when set, is forwarded as extra keyword
+    arguments to ``clickhouse_connect.get_client``.
+    """
     # Clickhouse client is not thread-safe, so we need to create a separate instance for each thread.
     if not hasattr(thread_local, "clickhouse_client"):
+        # additional_client_args is optional; None and {} both mean "no extra
+        # keyword arguments", so a single call site covers every case.
+        additional_client_args = config.additional_client_args or {}
         thread_local.clickhouse_client = clickhouse_connect.get_client(
             host=config.host,
             port=config.port,
             user=config.user,
             password=config.password,
             database=config.database,
+            **additional_client_args,
         )
 
     return thread_local.clickhouse_client
diff --git a/sdk/python/tests/unit/infra/offline_stores/test_clickhouse.py b/sdk/python/tests/unit/infra/offline_stores/test_clickhouse.py
index 38c632a59a7..f5440ed367d 100644
--- a/sdk/python/tests/unit/infra/offline_stores/test_clickhouse.py
+++ b/sdk/python/tests/unit/infra/offline_stores/test_clickhouse.py
@@ -1,3 +1,4 @@
+import logging
 import threading
 from unittest.mock import MagicMock, patch
 
@@ -6,6 +7,8 @@
 from feast.infra.utils.clickhouse.clickhouse_config import ClickhouseConfig
 from feast.infra.utils.clickhouse.connection_utils import get_client, thread_local
 
+logger = logging.getLogger(__name__)
+
 
 @pytest.fixture
 def clickhouse_config():
@@ -76,3 +79,57 @@ def thread_2_work():
     assert client_1a is not client_2, (
         "Different threads should get different client instances (not cached)"
     )
+
+
+def test_clickhouse_config_parses_additional_client_args():
+    """
+    Test that ClickhouseConfig correctly parses additional_client_args from a dict,
+    simulating how it would be parsed from YAML by Pydantic.
+    """
+    # This simulates the dict that would come from yaml.safe_load()
+    raw_config = {
+        "host": "localhost",
+        "port": 8123,
+        "database": "default",
+        "user": "default",
+        "password": "password",
+        "additional_client_args": {
+            "send_receive_timeout": 60,
+            "compress": True,
+            "client_name": "feast_test",
+        },
+    }
+
+    # Pydantic should parse this dict into a ClickhouseConfig object
+    config = ClickhouseConfig(**raw_config)
+
+    # Verify all fields are correctly parsed
+    assert config.host == "localhost"
+    assert config.port == 8123
+    assert config.database == "default"
+    assert config.user == "default"
+    assert config.password == "password"
+
+    # Verify additional_client_args is correctly parsed as a dict
+    assert config.additional_client_args is not None
+    assert isinstance(config.additional_client_args, dict)
+    assert config.additional_client_args["send_receive_timeout"] == 60
+    assert config.additional_client_args["compress"] is True
+    assert config.additional_client_args["client_name"] == "feast_test"
+
+
+def test_clickhouse_config_handles_none_additional_client_args():
+    """
+    Test that ClickhouseConfig correctly handles when additional_client_args is not provided.
+    """
+    raw_config = {
+        "host": "localhost",
+        "port": 8123,
+        "database": "default",
+        "user": "default",
+        "password": "password",
+    }
+
+    config = ClickhouseConfig(**raw_config)
+
+    assert config.additional_client_args is None