# Class-Based Middleware

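This notebook demonstrates how to write middleware as classes that subclass `AgentMiddleware` (which wraps agent runs) or `FunctionMiddleware` (which wraps tool calls). Each class implements an async `process(context, next)` method.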

## Key Concepts

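- **Stateful Middleware**: Instances can hold state, such as counters, caches, or configuration, that persists across invocations
- **Multiple Hooks**: Different base classes hook different stages. `AgentMiddleware` wraps an agent run and `FunctionMiddleware` wraps each tool call; inside `process`, code can run before and after `await next(context)`
- **Reusability**: A single instance can be passed to several agents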

## Prerequisites

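- Azure AI project endpoint (`AZURE_AI_PROJECT_ENDPOINT`) and model deployment name (`AZURE_AI_MODEL_DEPLOYMENT_NAME`) configured in `../.env`
- `agent-framework` package installed (the notebook also imports `python-dotenv`, `azure-identity`, and `pydantic`)
- Azure CLI authentication (run `az login` before starting the notebook)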

In [None]:
import os
from dotenv import load_dotenv

load_dotenv('../.env')

project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
model_deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME")

print(f"Project Endpoint: {project_endpoint}")
print(f"Model Deployment: {model_deployment_name}")
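`os.getenv` returns `None` for unset variables, so a missing `.env` entry would only print `None` here and likely surface later as a less obvious error when the client is created. A small helper can fail fast instead. `require_env` is a hypothetical name for illustration, not part of `python-dotenv` or `agent-framework`.

```python
import os


def require_env(*names: str) -> dict[str, str]:
    """Return the requested environment variables, raising if any are unset or empty."""
    values = {name: os.getenv(name, "") for name in names}
    missing = [name for name, value in values.items() if not value]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
    return values


# Usage, after load_dotenv() has populated os.environ:
# settings = require_env("AZURE_AI_PROJECT_ENDPOINT", "AZURE_AI_MODEL_DEPLOYMENT_NAME")
```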

In [None]:
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated

from agent_framework import (
    AgentMiddleware,
    AgentRunContext,
    AgentRunResponse,
    ChatMessage,
    FunctionInvocationContext,
    FunctionMiddleware,
    Role,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field

## Define Tool Function

In [None]:
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."

## Class-Based Middleware Examples

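The classes in this section illustrate two types of class-based middleware:
- **SecurityAgentMiddleware** (subclass of `AgentMiddleware`): runs around each agent run, inspects the last message, and blocks requests that mention passwords or secrets
- **LoggingFunctionMiddleware** (subclass of `FunctionMiddleware`): runs around each tool call and logs the function name and execution time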

In [None]:
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that checks for security violations."""

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        # Check for potential security violations in the query
        last_message = context.messages[-1] if context.messages else None
        if last_message and last_message.text:
            query = last_message.text
            if "password" in query.lower() or "secret" in query.lower():
                print("[SecurityMiddleware] Security Warning: Detected sensitive information, blocking request.")
                # Override the result with warning message
                context.result = AgentRunResponse(
                    messages=[
                        ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.")
                    ]
                )
                # Return without calling next(): the rest of the pipeline, including the agent run, is skipped
                return

        print("[SecurityMiddleware] Security check passed.")
        await next(context)


class LoggingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that logs function calls with timing."""

    async def process(
        self,
        context: FunctionInvocationContext,
        next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        function_name = context.function.name
        print(f"[LoggingMiddleware] About to call function: {function_name}")

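        # perf_counter() is monotonic, so it is better suited to measuring durations than time()
        start_time = time.perf_counter()
        try:
            await next(context)
        finally:
            # Log the duration even if the function raised an exception
            duration = time.perf_counter() - start_time
            print(f"[LoggingMiddleware] Function {function_name} finished in {duration:.5f}s")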
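The substring test in `SecurityAgentMiddleware` is deliberately simple, and it also matches harmless words such as "secretary" or "passwordless". If that matters for your scenario, a regular expression with word boundaries is one tighter alternative. The helper below is a hypothetical sketch, not part of `agent_framework`.

```python
import re

# \b anchors the match at word boundaries, so "secret" matches but "secretary" does not.
# The optional "s?" also accepts plurals such as "passwords" or "secrets".
SENSITIVE_PATTERN = re.compile(r"\b(password|secret)s?\b", re.IGNORECASE)


def contains_sensitive_terms(text: str) -> bool:
    """Return True if the text mentions a blocked term as a whole word."""
    return SENSITIVE_PATTERN.search(text) is not None


print(contains_sensitive_terms("What's the password for the weather service?"))  # True
print(contains_sensitive_terms("Ask the secretary about the forecast."))        # False
```

Inside `process`, the check would then read `if contains_sensitive_terms(query):`.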

## Example: Using Class-Based Middleware

In [None]:
async def main():
    """Example demonstrating class-based middleware."""
    print("=== Class-based Middleware Example ===\n")
    
    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
        ) as agent,
    ):
        # Test with normal query
        print("--- Normal Query ---")
        query = "What's the weather like in Seattle?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}\n")

        # Test with security-related query
        print("--- Security Test ---")
        query = "What's the password for the weather service?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}")

await main()

## Key Takeaways

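- Class-based middleware can keep state, such as counters, caches, or configuration, on the instance (the two examples here happen to be stateless)
- Subclass `AgentMiddleware` to wrap agent runs or `FunctionMiddleware` to wrap tool calls, and implement `async process(context, next)`
- Instances can be reused across multiple agents; any state an instance holds is then shared by those agents
- Classes suit middleware that needs initialization (constructor arguments, shared resources) or cleanup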
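To make the state and reuse points concrete, the hypothetical `CallCounter` below keeps a per-instance count keyed by agent name. Because one instance is shared, every agent it is attached to contributes to the same count. The sketch drives `process` directly with a stand-in `next` and plain dictionaries as contexts instead of a real agent; in the notebook you would pass the same instance in each agent's `middleware=[...]` list.

```python
import asyncio
from collections import Counter
from collections.abc import Awaitable, Callable
from typing import Any


class CallCounter:
    """Hypothetical stateful middleware: counts runs per agent name across all agents it serves."""

    def __init__(self) -> None:
        self.counts: Counter[str] = Counter()

    async def process(self, context: Any, next: Callable[[Any], Awaitable[None]]) -> None:
        # No await between reading and writing the count, so other asyncio tasks cannot interleave here
        self.counts[context["agent"]] += 1
        await next(context)


async def noop(context: Any) -> None:
    """Stand-in for the rest of the pipeline."""


async def demo(counter: CallCounter) -> None:
    # The same instance serves two hypothetical agents
    for agent in ("WeatherAgent", "WeatherAgent", "TravelAgent"):
        await counter.process({"agent": agent}, noop)


shared = CallCounter()
asyncio.run(demo(shared))
print(shared.counts)  # Counter({'WeatherAgent': 2, 'TravelAgent': 1})
```

If per-agent isolation is wanted instead, create a separate instance for each agent. In Jupyter, use `await demo(shared)` in place of `asyncio.run(...)`.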

## Next Steps

- **Decorator middleware** (notebook 4) for simplified syntax
- **Exception handling** (notebook 6) for robust error management