In [None]:
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
import os
from dotenv import load_dotenv
from agent_framework import FunctionInvocationContext, tool

from agent_framework import (
    AgentContext,
    AgentMiddleware,
    AgentResponse,
    Message,
    MiddlewareTermination,
    tool,
)
from agent_framework.azure import AzureOpenAIResponsesClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field

In [None]:
# Load settings from a local .env file before reading them below.
# override=True is passed so values from .env take precedence over variables
# already set in the process environment (see python-dotenv docs).
load_dotenv(override=True)

# Both lookups yield None when the variable is missing; nothing is validated here.
project_endpoint = os.environ.get("AZURE_AI_PROJECT_ENDPOINT")
model = os.environ.get("AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME")

# Echo the resolved configuration so a missing value shows up as "None"
# before the client is constructed in a later cell.
print("Project Endpoint: ", project_endpoint)
print("Model: ", model)

In [None]:
@tool(approval_mode="never_require")
def unstable_data_service(
    query: Annotated[str, Field(description="The data query to execute.")],
) -> str:
    """A simulated data service that sometimes throws exceptions."""
    # NOTE(review): the docstring is left unchanged on purpose - the @tool
    # decorator presumably exposes it to the model as the tool description.
    #
    # `query` is accepted but never read: this stub fails on every call so the
    # TimeoutError handling in exception_handling_middleware can be demonstrated.
    simulated_failure = TimeoutError("Data service request timed out")
    raise simulated_failure

In [None]:
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # NOTE(review): the docstring is left unchanged on purpose - the @tool
    # decorator presumably exposes it to the model as the tool description.
    #
    # Simulated forecast: the condition and the high temperature are both drawn
    # at random, so repeated calls for the same location can disagree.
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # Index bound derived from the list so adding a condition cannot go out of sync.
    condition = conditions[randint(0, len(conditions) - 1)]
    high_celsius = randint(10, 30)
    # The degree sign is written as an escape (U+00B0) so a wrong file encoding
    # cannot turn it into mojibake ("Â°") in the text returned to the model.
    return f"The weather in {location} is {condition} with a high of {high_celsius}\u00b0C."


In [None]:
async def exception_handling_middleware(
    context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
    """Function middleware that turns a tool ``TimeoutError`` into a friendly result.

    Logs the start of the function invocation, awaits the rest of the pipeline,
    and logs success. If a ``TimeoutError`` propagates out of ``call_next``, it is
    swallowed and ``context.result`` is replaced with a fixed apology message.
    Any other exception type is not handled here and propagates to the caller.

    Args:
        context: Invocation context; ``context.function.name`` is read for logging
            and ``context.result`` is overwritten on timeout.
        call_next: Awaitable callback that runs the next middleware / the function.
    """
    function_name = context.function.name

    try:
        print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
        await call_next()
    except TimeoutError as timeout_error:
        print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {timeout_error}")
        # Adjacent string literals inside parentheses concatenate into ONE str.
        # Do not add a comma between them: that would create a tuple, and the
        # result must be a plain string or API validation fails.
        fallback_message = (
            "Request Timeout: The data service is taking longer than expected to respond. "
            "Please inform the user: 'Sorry for the inconvenience, please try again later.'"
        )
        # Override the function result so the custom message ends up in the response.
        context.result = fallback_message
    else:
        print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")


In [None]:
class PreTerminationMiddleware(AgentMiddleware):
    """Agent middleware that terminates a run before the agent is called.

    The text of the most recent message in ``context.messages`` is lower-cased
    and searched for each configured blocked word. On the first hit the
    middleware sets a canned assistant reply as ``context.result`` and raises
    ``MiddlewareTermination`` so nothing further in the pipeline runs. When no
    blocked word is found, processing continues via ``call_next``.

    Matching is a case-insensitive *substring* test, so a blocked word such as
    "bad" also matches inside longer words (e.g. "badge").
    """

    def __init__(self, blocked_words: list[str]):
        """Store the blocked words in lower case.

        Args:
            blocked_words: Words that cause a request to be rejected. Empty
                strings are ignored: ``"" in text`` is always true, so an empty
                entry would otherwise block every single request.
        """
        self.blocked_words = [word.lower() for word in blocked_words if word]

    async def process(
        self,
        context: AgentContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Reject the request if its last message contains a blocked word.

        Args:
            context: Agent context; ``messages`` is read and ``result`` is set
                when the request is rejected.
            call_next: Awaitable callback that continues the pipeline.

        Raises:
            MiddlewareTermination: When a blocked word is found; it carries the
                same response that was stored on ``context.result``.
        """
        # Only the most recent message is inspected; earlier messages are not scanned.
        last_message = context.messages[-1] if context.messages else None

        if last_message and last_message.text:
            query = last_message.text.lower()

            for blocked_word in self.blocked_words:
                if blocked_word in query:
                    print(f"[PreTerminationMiddleware] Blocked word '{blocked_word}' detected. Terminating request.")

                    # Set a custom response that is returned instead of an agent answer.
                    context.result = AgentResponse(
                        messages=[
                            Message(
                                role="assistant",
                                text=(
                                    f"Sorry, I cannot process requests containing '{blocked_word}'. "
                                    "Please rephrase your question."
                                ),
                            )
                        ]
                    )

                    # Terminate to prevent further processing (first match wins).
                    raise MiddlewareTermination(result=context.result)

        await call_next()

In [None]:
# Authenticate with the Azure CLI login (async credential from azure.identity.aio).
# NOTE(review): this credential is never closed in the visible cells - consider
# `await credential.close()` in a final cell once the demo is finished.
credential = AzureCliCredential()

# Responses client bound to the endpoint and deployment read from the environment.
client = AzureOpenAIResponsesClient(
    project_endpoint=project_endpoint,
    deployment_name=model,
    credential=credential,
)

# The agent gets both demo tools plus two middlewares:
#   * PreTerminationMiddleware   - works on the AgentContext (blocked-word check)
#   * exception_handling_middleware - works on the FunctionInvocationContext
# NOTE(review): presumably the framework routes each middleware by its context
# type; confirm against the agent_framework middleware documentation.
agent = client.as_agent(
    name="CustomBot",
    instructions="""
        You are a helpful assistant that remembers our conversation. 
        Use the data service tool to fetch information for users.
        If the data service tool fails, respond with the error message provided by the middleware.""",
    tools=[get_weather, unstable_data_service],
    middleware=[
        PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]),
        exception_handling_middleware,
    ],
)

In [None]:
# Test with normal query
print("\n1. Normal query:")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}")

# Test with blocked word
print("\n2. Query with blocked word:")
query = "What's the bad weather in New York?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}")

In [None]:
query = "Get user statistics"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}")