diff --git a/temporalio/bridge/Cargo.lock b/temporalio/bridge/Cargo.lock index c31dafdb6..4bed007c8 100644 --- a/temporalio/bridge/Cargo.lock +++ b/temporalio/bridge/Cargo.lock @@ -2704,6 +2704,7 @@ dependencies = [ "anyhow", "async-trait", "futures", + "http", "prost", "pyo3", "pyo3-async-runtimes", diff --git a/temporalio/bridge/Cargo.toml b/temporalio/bridge/Cargo.toml index e1d375a76..8ad32478c 100644 --- a/temporalio/bridge/Cargo.toml +++ b/temporalio/bridge/Cargo.toml @@ -19,6 +19,7 @@ crate-type = ["cdylib"] anyhow = "1.0" async-trait = "0.1" futures = "0.3" +http = "1" prost = "0.14" pyo3 = { version = "0.25", features = [ "extension-module", diff --git a/temporalio/bridge/client.py b/temporalio/bridge/client.py index 564005ef0..18ce54ade 100644 --- a/temporalio/bridge/client.py +++ b/temporalio/bridge/client.py @@ -71,6 +71,7 @@ class ClientConfig: client_name: str client_version: str http_connect_proxy_config: ClientHttpConnectProxyConfig | None + override_origin: str | None @dataclass diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index 3e91f2110..17ea09314 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -1,3 +1,4 @@ +use http::Uri; use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use std::collections::HashMap; @@ -35,6 +36,7 @@ pub struct ClientConfig { retry_config: Option, keep_alive_config: Option, http_connect_proxy_config: Option, + override_origin: Option, } #[derive(FromPyObject)] @@ -256,6 +258,15 @@ impl ClientConfig { } else { None }) + .maybe_override_origin(if let Some(origin) = self.override_origin { + Some( + origin + .parse::() + .map_err(|err| PyValueError::new_err(format!("invalid override_origin: {err}")))?, + ) + } else { + None + }) .maybe_metrics_meter(metrics_meter); Ok(conn_opts.build()) } diff --git a/temporalio/client.py b/temporalio/client.py index f781774c1..6eaa891b0 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -139,6 +139,7 @@ async def connect( lazy: bool = False, runtime: temporalio.runtime.Runtime | None = None, http_connect_proxy_config: HttpConnectProxyConfig | None = None, + override_origin: str | None = None, header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC, ) -> Self: """Connect to a Temporal server. @@ -194,6 +195,12 @@ async def connect( used for workers. runtime: The runtime for this client, or the default if unset. http_connect_proxy_config: Configuration for HTTP CONNECT proxy. + override_origin: If set, override the value of the gRPC ``:authority`` + pseudo-header on every RPC call. Useful when connecting through a + local proxy (e.g. an Envoy sidecar) that routes by the + ``:authority`` header, while the connection target is something like + ``localhost:``. Mirrors ``client.Options.HostPort`` in the Go + SDK and ``ConnectionOptions::override_origin`` in Rust Core. header_codec_behavior: Encoding behavior for headers sent by the client. """ connect_config = temporalio.service.ConnectConfig( @@ -207,6 +214,7 @@ async def connect( lazy=lazy, runtime=runtime, http_connect_proxy_config=http_connect_proxy_config, + override_origin=override_origin, ) def make_lambda( @@ -2825,6 +2833,7 @@ class ClientConnectConfig(TypedDict, total=False): lazy: bool runtime: temporalio.runtime.Runtime | None http_connect_proxy_config: HttpConnectProxyConfig | None + override_origin: str | None header_codec_behavior: HeaderCodecBehavior @@ -9721,6 +9730,7 @@ async def connect( lazy: bool = False, runtime: temporalio.runtime.Runtime | None = None, http_connect_proxy_config: HttpConnectProxyConfig | None = None, + override_origin: str | None = None, ) -> CloudOperationsClient: """Connect to a Temporal Cloud Operations API. @@ -9757,6 +9767,9 @@ async def connect( used for workers. runtime: The runtime for this client, or the default if unset. http_connect_proxy_config: Configuration for HTTP CONNECT proxy. + override_origin: If set, override the value of the gRPC ``:authority`` + pseudo-header on every RPC call. See :py:meth:`Client.connect` for + details. """ # Add version if given if version: @@ -9773,6 +9786,7 @@ async def connect( lazy=lazy, runtime=runtime, http_connect_proxy_config=http_connect_proxy_config, + override_origin=override_origin, ) return CloudOperationsClient( await temporalio.service.ServiceClient.connect(connect_config) diff --git a/temporalio/service.py b/temporalio/service.py index cbb3dc9be..12e5e9492 100644 --- a/temporalio/service.py +++ b/temporalio/service.py @@ -146,6 +146,7 @@ class ConnectConfig: lazy: bool = False runtime: temporalio.runtime.Runtime | None = None http_connect_proxy_config: HttpConnectProxyConfig | None = None + override_origin: str | None = None def __post_init__(self) -> None: """Set extra defaults on unset properties.""" @@ -203,6 +204,7 @@ def _to_bridge_config(self) -> temporalio.bridge.client.ClientConfig: if self.http_connect_proxy_config else None ), + override_origin=self.override_origin, ) diff --git a/tests/test_service.py b/tests/test_service.py index 9fdcd9fc7..b92e2667e 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -219,6 +219,91 @@ def test_connect_config_tls_explicit_config_preserved(): assert config.tls == tls_config +def test_connect_config_override_origin_forwarded_to_bridge(): + """override_origin is passed through to the bridge ClientConfig.""" + config = temporalio.service.ConnectConfig( + target_host="localhost:7233", + override_origin="http://temporal-frontend", + ) + bridge_config = config._to_bridge_config() + assert bridge_config.override_origin == "http://temporal-frontend" + + +def test_connect_config_override_origin_defaults_to_none(): + """override_origin defaults to None and is forwarded as None.""" + config = temporalio.service.ConnectConfig(target_host="localhost:7233") + bridge_config = config._to_bridge_config() + assert bridge_config.override_origin is None + + +async def test_override_origin_end_to_end(): + """override_origin reaches the gRPC channel and the client can connect. + + A minimal grpc.aio server is spun up on localhost. We connect the + Temporal service client with override_origin set to an arbitrary value. + The connection succeeds because the fake server is reachable; the + override_origin value changes the HTTP/2 :authority pseudo-header on every + call (verified by the Rust bridge unit tests). This test confirms the full + Python stack — from ConnectConfig down through the Rust bridge — does not + break when override_origin is set. + + Note: Python's grpc.aio server does not expose the :authority pseudo-header + in ServicerContext.invocation_metadata() (it is an HTTP/2 transport-level + header, not an application-level gRPC metadata entry). The per-header + behaviour is covered by the Rust Core integration tests for + ConnectionOptions::override_origin. + """ + import socket as _socket + + import grpc.aio + from temporalio.api.workflowservice.v1 import ( + request_response_pb2 as ws_pb2, + service_pb2_grpc as ws_grpc, + ) + + received_calls: list[dict] = [] + + class _FakeWorkflowService(ws_grpc.WorkflowServiceServicer): # type: ignore[misc] + async def GetSystemInfo(self, request, context): # type: ignore[override] + received_calls.append({"metadata": dict(context.invocation_metadata())}) + return ws_pb2.GetSystemInfoResponse(server_version="test") + + # Pick a free port. + with _socket.socket() as _s: + _s.bind(("127.0.0.1", 0)) + port = _s.getsockname()[1] + + server = grpc.aio.server() + ws_grpc.add_WorkflowServiceServicer_to_server(_FakeWorkflowService(), server) # type: ignore[arg-type] + server.add_insecure_port(f"127.0.0.1:{port}") + await server.start() + try: + config = temporalio.service.ConnectConfig( + target_host=f"localhost:{port}", + override_origin="http://temporal-frontend", + ) + await temporalio.service.ServiceClient.connect(config) + finally: + await server.stop(0) + + assert received_calls, "GetSystemInfo was never called — client did not connect" + # Verify that the SDK sends its standard gRPC metadata regardless of the + # override_origin setting (regression guard). + meta = received_calls[0]["metadata"] + assert "client-name" in meta + assert meta["client-name"] == "temporal-python" + + +async def test_override_origin_invalid_uri_raises(): + """An unparseable override_origin URI raises ValueError before connecting.""" + config = temporalio.service.ConnectConfig( + target_host="localhost:7233", + override_origin="this is not\na valid uri", + ) + with pytest.raises(Exception, match="invalid override_origin"): + await temporalio.service.ServiceClient.connect(config) + + async def test_rpc_execution_not_unknown(client: Client): """ Execute each rpc method and expect a failure, but ensure the failure is not that the rpc method is unknown