Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@
import asyncio
import json
import logging
import os
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping

from agent_framework import ChatOptions, Content, HistoryProvider, Message, RawAgent, SupportsAgentRun
from agent_framework import (
ChatOptions,
Content,
ContextProvider,
FileCheckpointStorage,
HistoryProvider,
Message,
RawAgent,
SupportsAgentRun,
WorkflowAgent,
)
from agent_framework._telemetry import append_to_user_agent
from azure.ai.agentserver.responses import (
ResponseContext,
Expand Down Expand Up @@ -60,6 +71,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
"""A responses server host for an agent."""

USER_AGENT_PREFIX = "foundry-hosting"
# TODO(@taochen): Allow a different checkpoint storage that stores checkpoints externally
CHECKPOINT_STORAGE_PATH = "/.checkpoints"
Comment thread
TaoChenOSU marked this conversation as resolved.

def __init__(
self,
Expand All @@ -80,8 +93,11 @@ def __init__(
**kwargs: Additional keyword arguments.

Note:
The agent must not have a history provider with `load_messages=True`,
because history is managed by the hosting infrastructure.
1. The agent must not have a history provider with `load_messages=True`,
Comment thread
TaoChenOSU marked this conversation as resolved.
because history is managed by the hosting infrastructure.
2. The agent must not have any context providers that maintain context
in memory, because the hosting environment may get deactivated between
requests, and any in-memory context would be lost.
"""
super().__init__(prefix=prefix, options=options, store=store, **kwargs)

Expand All @@ -91,72 +107,222 @@ def __init__(
"There shouldn't be a history provider with `load_messages=True` already present. "
"History is managed by the hosting infrastructure."
)
self._agent = agent
provider = cast(ContextProvider, provider)
logger.warning(
"Context provider %s is present. If it maintains context in memory, "
"the context may be lost between requests. Use with caution.",
provider.source_id,
)

self._is_workflow_agent = False
self._checkpoint_storage_path = None
if isinstance(agent, WorkflowAgent):
if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage]
raise RuntimeError(
"There should not be a checkpoint storage already present in the workflow agent. "
"The hosting infrastructure will manage checkpoints instead."
)
self._checkpoint_storage_path = (
self.CHECKPOINT_STORAGE_PATH
if self.config.is_hosted
else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/"))
)
self._is_workflow_agent = True

self._agent = agent
self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType]

# Append the user agent prefix for telemetry purposes
append_to_user_agent(self.USER_AGENT_PREFIX)

@staticmethod
def _is_streaming_request(request: CreateResponse) -> bool:
"""Check if the request is a streaming request."""
return request.stream is not None and request.stream is True

async def _handler(
self,
request: CreateResponse,
context: ResponseContext,
cancellation_signal: asyncio.Event,
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
"""Handle the creation of a response."""
if self._is_workflow_agent:
# Workflow agents are handled differently because they require checkpoint restoration
async for event in self._handle_workflow_agent(request, context, cancellation_signal):
yield event
return

input_text = await context.get_input_text()
history = await context.get_history()
messages = [*_to_messages(history), input_text]

chat_options = _to_chat_options(request)
chat_options, are_options_set = _to_chat_options(request)

stream = ResponseEventStream(response_id=context.response_id, model=request.model)
is_streaming_request = self._is_streaming_request(request)
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)

yield stream.emit_created()
yield stream.emit_in_progress()
yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()

if request.stream is None or request.stream is False:
if not is_streaming_request:
# Run the agent in non-streaming mode
if isinstance(self._agent, RawAgent):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response = await raw_agent.run(messages, stream=False, options=chat_options)
else:
if are_options_set:
logger.warning("Agent doesn't support runtime options. They will be ignored.")
response = await self._agent.run(messages, stream=False)

for message in response.messages:
for content in message.contents:
async for item in _to_outputs(stream, content):
async for item in _to_outputs(response_event_stream, content):
yield item

yield stream.emit_completed()
yield response_event_stream.emit_completed()
return

# Start the streaming response
# Run the agent in streaming mode
if isinstance(self._agent, RawAgent):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response_stream = raw_agent.run(messages, stream=True, options=chat_options)
else:
if are_options_set:
logger.warning("Agent doesn't support runtime options. They will be ignored.")
response_stream = self._agent.run(messages, stream=True)

# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
tracker = _OutputItemTracker(stream)
tracker = _OutputItemTracker(response_event_stream)

async for update in response_stream:
for content in update.contents:
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(response_event_stream, content):
yield item
tracker.needs_async = False

# Close any remaining active builder
for event in tracker.close():
yield event

yield response_event_stream.emit_completed()

async def _handle_workflow_agent(
self,
request: CreateResponse,
context: ResponseContext,
cancellation_signal: asyncio.Event,
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
"""Handle the creation of a response for a workflow agent.

Why this is required:
The sandbox may be deactivated after some period of inactivity, and only data managed
by the hosting infrastructure or files will be preserved upon deactivation.
"""
input_text = await context.get_input_text()
Comment thread
TaoChenOSU marked this conversation as resolved.
is_streaming_request = self._is_streaming_request(request)

_, are_options_set = _to_chat_options(request)
if are_options_set:
logger.warning("Workflow agent doesn't support runtime options. They will be ignored.")

if request.previous_response_id is not None and context.conversation_id is not None:
raise RuntimeError("Previous response ID cannot be used in conjunction with conversation ID.")
context_id = request.previous_response_id or context.conversation_id

# The following should never happen due to the checks above.
# This is for type safety and defensive programming.
if self._checkpoint_storage_path is None:
raise RuntimeError("Checkpoint storage path is not configured for workflow agent.")
if not isinstance(self._agent, WorkflowAgent):
raise RuntimeError("Agent is not a workflow agent.")

# Restore from the latest checkpoint if available, otherwise start with an empty history
if context_id is not None:
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
if latest_checkpoint is not None:
if not is_streaming_request:
_ = await self._agent.run(
stream=False,
checkpoint_id=latest_checkpoint.checkpoint_id,
checkpoint_storage=checkpoint_storage,
)
else:
# Consume the streaming or the invocation will result in a no-op
async for _ in self._agent.run(
stream=True,
checkpoint_id=latest_checkpoint.checkpoint_id,
checkpoint_storage=checkpoint_storage,
):
pass

# Now run the agent with the latest input
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)

# Create a new checkpoint storage for this response based on the following rules:
# - If no previous response ID or conversation ID is provided, create a new checkpoint storage for this response
# - If a previous response ID is provided, create a new checkpoint storage for this response
# - If a conversation ID is provided, reuse the existing checkpoint storage for the conversation
context_id = context.conversation_id or context.response_id
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))

yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()

if not is_streaming_request:
# Run the agent in non-streaming mode
response = await self._agent.run(input_text, stream=False, checkpoint_storage=checkpoint_storage)

for message in response.messages:
for content in message.contents:
async for item in _to_outputs(response_event_stream, content):
yield item

await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
yield response_event_stream.emit_completed()
return

# Run the agent in streaming mode
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=checkpoint_storage)

# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
tracker = _OutputItemTracker(response_event_stream)

async for update in response_stream:
for content in update.contents:
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(stream, content):
async for item in _to_outputs(response_event_stream, content):
yield item
tracker.needs_async = False

# Close any remaining active builder
for event in tracker.close():
yield event

yield stream.emit_completed()
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
yield response_event_stream.emit_completed()
return

@staticmethod
async def _delete_not_latest_checkpoints(checkpoint_storage: FileCheckpointStorage, workflow_name: str):
"""Delete all checkpoints except the latest one.

We only need the last checkpoint for each invocation.
"""
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow_name)
if latest_checkpoint is not None:
all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow_name)
for checkpoint in all_checkpoints:
if checkpoint.checkpoint_id != latest_checkpoint.checkpoint_id:
await checkpoint_storage.delete(checkpoint.checkpoint_id)


# region Active Builder State
Expand Down Expand Up @@ -310,27 +476,34 @@ def _close(self) -> Generator[ResponseStreamEvent, None, None]:
# region Option Conversion


def _to_chat_options(request: CreateResponse) -> ChatOptions:
def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]:
"""Converts a CreateResponse request to ChatOptions.

Args:
request (CreateResponse): The request to convert.

Returns:
ChatOptions: The converted ChatOptions.
bool: Whether any options were set.

"""
chat_options = ChatOptions()
are_options_set = False

if request.temperature is not None:
chat_options["temperature"] = request.temperature
are_options_set = True
if request.top_p is not None:
chat_options["top_p"] = request.top_p
are_options_set = True
if request.max_output_tokens is not None:
chat_options["max_tokens"] = request.max_output_tokens
are_options_set = True
if request.parallel_tool_calls is not None:
chat_options["allow_multiple_tool_calls"] = request.parallel_tool_calls
are_options_set = True

return chat_options
return chat_options, are_options_set


# endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from agent_framework import Agent
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv

Expand All @@ -29,7 +28,7 @@ def main():
default_options={"store": False},
)

server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from agent_framework import Agent, tool
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
Expand Down Expand Up @@ -67,7 +66,7 @@ def main():
default_options={"store": False},
)

server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from agent_framework import Agent, MCPStreamableHTTPTool
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv

Expand Down Expand Up @@ -69,7 +68,7 @@ def main():
default_options={"store": False},
)

server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ curl -X POST http://localhost:8088/responses -H "Content-Type: application/json"
Invoke with `azd`:

```bash
azd ai agent invoke --local "List all the repositories I own on GitHub."
azd ai agent invoke --local "Create a slogan for a new electric SUV that is affordable and fun to drive."
```
Loading