Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
self._timeout_config = self._create_timeout_config(timeout)
if client is not None:
self.client = client
self._non_streaming_client: Client | None = None
self._close_http_client = True
return
if agent_card is None:
Expand All @@ -144,17 +145,30 @@ def __init__(
self._http_client = http_client # Store for cleanup
self._close_http_client = True

# Create A2A client using factory
config = ClientConfig(
interceptors = [auth_interceptor] if auth_interceptor is not None else None

# Create streaming client (SSE transport for stream=True)
streaming_config = ClientConfig(
httpx_client=http_client,
streaming=True,
supported_protocol_bindings=["JSONRPC"],
Comment thread
giles17 marked this conversation as resolved.
)
factory = ClientFactory(config)
interceptors = [auth_interceptor] if auth_interceptor is not None else None
# Create non-streaming client (single request/response for stream=False)
non_streaming_config = ClientConfig(
httpx_client=http_client,
streaming=False,
supported_protocol_bindings=["JSONRPC"],
)
streaming_factory = ClientFactory(streaming_config)
non_streaming_factory = ClientFactory(non_streaming_config)

# Attempt transport negotiation with the provided agent card
try:
self.client = factory.create(agent_card, interceptors=interceptors) # type: ignore
self.client = streaming_factory.create(agent_card, interceptors=interceptors) # type: ignore
self._non_streaming_client = non_streaming_factory.create(
agent_card,
interceptors=interceptors, # type: ignore
)
except Exception as transport_error:
# Transport negotiation failed - fall back to minimal agent card with JSONRPC
fallback_url = agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url
Expand All @@ -166,7 +180,11 @@ def __init__(
) from transport_error
fallback_card = minimal_agent_card(fallback_url, ["JSONRPC"])
try:
self.client = factory.create(fallback_card, interceptors=interceptors) # type: ignore
self.client = streaming_factory.create(fallback_card, interceptors=interceptors) # type: ignore
self._non_streaming_client = non_streaming_factory.create(
fallback_card,
interceptors=interceptors, # type: ignore
)
except Exception as fallback_error:
raise RuntimeError(
f"A2A transport negotiation failed. "
Expand Down Expand Up @@ -282,6 +300,13 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
del function_invocation_kwargs, client_kwargs, kwargs
normalized_messages = normalize_messages(messages)

# Use non-streaming transport for non-streaming calls when available.
# This sends a single HTTP request/response instead of opening an SSE
# connection, matching the protocol's intent for synchronous operations.
active_client = (
self._non_streaming_client if (not stream and self._non_streaming_client is not None) else self.client
)

if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.subscribe(
SubscribeToTaskRequest(id=continuation_token["task_id"])
Expand All @@ -293,7 +318,11 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
normalized_messages[-1],
context_id=session.service_session_id if session else None,
)
a2a_stream = self.client.send_message(SendMessageRequest(message=a2a_message))
request = SendMessageRequest(message=a2a_message)
if background and not stream:
# return_immediately only applies to non-streaming (message/send)
request.configuration.return_immediately = True
Comment thread
giles17 marked this conversation as resolved.
a2a_stream = active_client.send_message(request)

provider_session = session
if provider_session is None and self.context_providers:
Expand Down
92 changes: 92 additions & 0 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self) -> None:
self.subscribe_responses: list[StreamResponse] = []
self.get_task_response: Task | None = None
self.last_message: Any = None
self.last_request: Any = None

def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None:
"""Add a mock Message response."""
Expand Down Expand Up @@ -91,6 +92,7 @@ def add_in_progress_task_response(

async def send_message(self, request: Any) -> AsyncIterator[StreamResponse]:
"""Mock send_message method that yields responses."""
self.last_request = request
self.last_message = getattr(request, "message", request)
self.call_count += 1
Comment thread
giles17 marked this conversation as resolved.

Expand Down Expand Up @@ -745,6 +747,96 @@ async def test_working_task_no_token_without_background(a2a_agent: A2AAgent, moc
assert response.continuation_token is None


async def test_background_sets_return_immediately_on_request(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that background=True sets return_immediately=True on SendMessageRequest configuration."""
mock_a2a_client.add_in_progress_task_response("task-bg", state=TaskState.TASK_STATE_WORKING)

await a2a_agent.run("Background task", background=True)

assert mock_a2a_client.last_request.configuration.return_immediately is True


async def test_foreground_does_not_set_return_immediately(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that background=False (default) does not set configuration on SendMessageRequest."""
mock_a2a_client.add_task_response("task-fg2", [{"id": "art-1", "content": "Done"}])

await a2a_agent.run("Foreground task")

assert mock_a2a_client.last_request.HasField("configuration") is False


async def test_streaming_background_does_not_set_return_immediately(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that background=True with stream=True does not set return_immediately.

Per A2A spec, return_immediately only applies to non-streaming (message/send).
"""
mock_a2a_client.add_task_response("task-sb", [{"id": "art-1", "content": "Streaming bg"}])

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Stream background", stream=True, background=True):
updates.append(update)

assert mock_a2a_client.last_request.HasField("configuration") is False


async def test_non_streaming_run_uses_non_streaming_client() -> None:
"""Test that stream=False uses the non-streaming client when available."""
streaming_client = MockA2AClient()
non_streaming_client = MockA2AClient()
non_streaming_client.add_task_response("task-ns", [{"id": "art-1", "content": "Non-streaming result"}])

Comment thread
giles17 marked this conversation as resolved.
agent = A2AAgent(name="Test Agent", id="test-ns", client=streaming_client, http_client=None)
agent._non_streaming_client = non_streaming_client # type: ignore[assignment]

response = await agent.run("Hello")

# Non-streaming client should have been called
assert non_streaming_client.call_count == 1
assert streaming_client.call_count == 0
assert response.messages[0].text == "Non-streaming result"
assert non_streaming_client.last_request.HasField("configuration") is False


async def test_streaming_run_uses_streaming_client() -> None:
"""Test that stream=True always uses the streaming client."""
streaming_client = MockA2AClient()
non_streaming_client = MockA2AClient()
streaming_client.add_task_response("task-s", [{"id": "art-1", "content": "Streaming result"}])

agent = A2AAgent(name="Test Agent", id="test-s", client=streaming_client, http_client=None)
agent._non_streaming_client = non_streaming_client # type: ignore[assignment]

updates: list[AgentResponseUpdate] = []
async for update in agent.run("Hello", stream=True):
updates.append(update)

# Streaming client should have been called
assert streaming_client.call_count == 1
assert non_streaming_client.call_count == 0
assert updates[0].contents[0].text == "Streaming result"


async def test_non_streaming_client_fallback_when_not_available(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that stream=False falls back to streaming client when non-streaming client is unavailable."""
mock_a2a_client.add_task_response("task-fb", [{"id": "art-1", "content": "Fallback result"}])

# a2a_agent is created with client= param so _non_streaming_client is None
assert a2a_agent._non_streaming_client is None

response = await a2a_agent.run("Hello")

assert mock_a2a_client.call_count == 1
assert response.messages[0].text == "Fallback result"


async def test_completed_task_has_no_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""Test that a completed task does not set a continuation token."""
mock_a2a_client.add_task_response("task-done", [{"id": "art-1", "content": "Result"}])
Expand Down
Loading