## 🔐 Prerequisites

Before running the first cell, sign in with the Azure CLI, because the agent authenticates through `AzureCliCredential`:

```bash
az login
```

# 🏛️ Class-Based Middleware

## Industry Use Case: Credit Limit Assessment

This notebook shows how to build middleware as **classes** that inherit from the `AgentMiddleware` or `FunctionMiddleware` base class.

| Feature | Benefit |
|---------|---------|
| **Stateful** | Can maintain state across invocations (counters, caches) |
| **Configurable** | Constructor parameters for customization |
| **Reusable** | Instances can be shared across agents |
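
All three properties in the table can be seen without the Agent Framework at all. The following is a minimal, framework-free sketch of the `process(context, next)` chaining style this notebook uses. `Context`, `CallCounter`, and `run_chain` are hypothetical stand-ins, not `agent_framework` APIs, and the framework's real dispatch logic may differ. A single `CallCounter` instance is shared by two chains, so its count accumulates across both.

```python
import asyncio
import functools
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field


@dataclass
class Context:
    """Hypothetical stand-in for a middleware context (not an agent_framework type)."""
    query: str
    log: list[str] = field(default_factory=list)


Next = Callable[[Context], Awaitable[None]]


class CallCounter:
    """Configurable, stateful middleware that records each call it sees."""

    def __init__(self, label: str):
        self.label = label  # configuration comes in through the constructor
        self.count = 0      # state persists between invocations

    async def process(self, context: Context, next: Next) -> None:
        self.count += 1
        context.log.append(f"{self.label} #{self.count}")
        await next(context)  # hand control to the rest of the chain


async def run_chain(middleware: list[CallCounter], context: Context) -> None:
    """Run the middleware in list order, ending at a fake 'agent' step."""

    async def agent_step(ctx: Context) -> None:
        ctx.log.append(f"agent handled: {ctx.query}")

    handler: Next = agent_step
    # Wrap from the inside out so that middleware[0] runs first.
    for mw in reversed(middleware):
        handler = functools.partial(mw.process, next=handler)
    await handler(context)


shared = CallCounter("shared")  # one instance reused by two "agents"
ctx_a, ctx_b = Context("agent A query"), Context("agent B query")


async def demo() -> None:
    await run_chain([shared], ctx_a)
    await run_chain([shared], ctx_b)


asyncio.run(demo())  # in a notebook cell, use `await demo()` instead
print(shared.count)  # 2
print(ctx_b.log)     # ['shared #2', 'agent handled: agent B query']
```

The constructor argument (`label`) is the configuration, `count` is the state, and passing the same instance to both chains is the reuse. Sharing an instance also means sharing its state, which is what the statistics printed at the end of the demo rely on.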

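### FSI Scenario
A credit limit agent that uses two class-based middleware components:
- **PII Protection**: Blocks requests that mention sensitive data (such as an SSN or a password) and counts how many were blocked
- **Request Counting**: Counts and times every tool call and enforces a simple cap on the total number of calls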

In [None]:
import os
from dotenv import load_dotenv

# Load environment variables from root .env
load_dotenv('../../.env', override=True)

PROJECT_ENDPOINT = os.environ["AI_FOUNDRY_PROJECT_ENDPOINT"]
MODEL_DEPLOYMENT = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o")

print(f"✅ Project Endpoint: {PROJECT_ENDPOINT[:50]}...")
print(f"✅ Model Deployment: {MODEL_DEPLOYMENT}")

## Import Libraries

In [None]:
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated

from agent_framework import (
    AgentMiddleware,
    AgentResponse,
    AgentRunContext,
    ChatMessage,
    FunctionInvocationContext,
    FunctionMiddleware,
    Role,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field

print("✅ Libraries imported")

## Define Credit Assessment Tools
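
`assess_credit_limit` approves the smaller of the requested amount and 100 times the credit score. It classifies risk as LOW above 700, MEDIUM above 600, and HIGH otherwise. Because the score comes from `randint`, the tool's output changes on every call. The sketch below applies the same arithmetic to an explicitly supplied score so the rules can be checked deterministically. `decide` is a hypothetical helper, not one of the notebook's tools.

```python
def decide(requested_amount: float, credit_score: int) -> dict:
    """Same rules as assess_credit_limit, but with the credit score passed in."""
    approved_limit = min(requested_amount, credit_score * 100)
    status = "APPROVED" if approved_limit >= requested_amount else "PARTIALLY APPROVED"
    if credit_score > 700:
        risk = "LOW"
    elif credit_score > 600:
        risk = "MEDIUM"
    else:
        risk = "HIGH"  # note: a score of exactly 600 lands here
    return {"approved_limit": approved_limit, "status": status, "risk": risk}


# Score 720 caps the limit at 72,000, so a 25,000 request is fully approved.
print(decide(25_000, 720))  # {'approved_limit': 25000, 'status': 'APPROVED', 'risk': 'LOW'}
# Score 550 caps the limit at 55,000, so an 80,000 request is only partially approved.
print(decide(80_000, 550))  # {'approved_limit': 55000, 'status': 'PARTIALLY APPROVED', 'risk': 'HIGH'}
```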

In [None]:
def assess_credit_limit(
    customer_id: Annotated[str, Field(description="Customer ID for credit assessment")],
    requested_amount: Annotated[float, Field(description="Requested credit limit in USD")],
) -> str:
    """Assess credit limit for a customer."""
    credit_score = randint(500, 850)
    approved_limit = min(requested_amount, credit_score * 100)
    
    return f"""
    Credit Assessment Report:
    - Customer ID: {customer_id}
    - Credit Score: {credit_score}
    - Requested Limit: ${requested_amount:,.2f}
    - Approved Limit: ${approved_limit:,.2f}
    - Status: {'APPROVED' if approved_limit >= requested_amount else 'PARTIALLY APPROVED'}
    - Risk Category: {'LOW' if credit_score > 700 else 'MEDIUM' if credit_score > 600 else 'HIGH'}
    """

def get_credit_history(
    customer_id: Annotated[str, Field(description="Customer ID to retrieve history")],
) -> str:
    """Get credit history for a customer."""
    return f"""
    Credit History - Customer {customer_id}:
    - Total Accounts: {randint(2, 10)}
    - On-time Payments: {randint(90, 100)}%
    - Average Account Age: {randint(2, 15)} years
    - Recent Inquiries: {randint(0, 3)}
    """

print("✅ Credit tools defined")

## Class-Based Middleware

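Each middleware class inherits from `AgentMiddleware` (which wraps a whole agent run) or `FunctionMiddleware` (which wraps each tool call) and overrides the async `process(context, next)` method. Awaiting `next(context)` passes control to the next middleware or to the underlying agent or tool. Returning without awaiting it short-circuits the call.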
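
`PIIProtectionMiddleware` detects PII with a case-insensitive substring test (`pattern in query`). That approach is simple but coarse. The default pattern `"pin"` also matches ordinary words such as "shopping" or "spinning". A bare SSN such as `123-45-6789` is caught only if the message also contains one of the keywords. The following minimal sketch shows a stricter check that uses word-boundary regexes plus an SSN-shaped number pattern. `contains_pii`, `build_keyword_re`, and `SSN_RE` are hypothetical names, and these regexes are illustrative, not a complete PII detector.

```python
import re

DEFAULT_PATTERNS = ["ssn", "social security", "password", "pin"]
# Illustrative US SSN shape: 3-2-4 digits separated by hyphens.
SSN_RE = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")


def build_keyword_re(patterns: list[str]) -> re.Pattern[str]:
    """Match any keyword only as a whole word or phrase, ignoring case."""
    alternatives = "|".join(re.escape(p) for p in patterns)
    return re.compile(rf"\b(?:{alternatives})\b", re.IGNORECASE)


KEYWORD_RE = build_keyword_re(DEFAULT_PATTERNS)


def contains_pii(text: str) -> bool:
    return bool(KEYWORD_RE.search(text) or SSN_RE.search(text))


print("pin" in "shopping for a card")               # True: substring false positive
print(contains_pii("shopping for a card"))           # False
print(contains_pii("My PIN is 1234"))                # True
print(contains_pii("Check credit for 123-45-6789"))  # True
```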

In [None]:
class PIIProtectionMiddleware(AgentMiddleware):
    """Agent middleware that blocks requests containing PII patterns."""
    
    def __init__(self, blocked_patterns: list[str] | None = None):
        """Initialize with configurable blocked patterns."""
        self.blocked_patterns = blocked_patterns or ["ssn", "social security", "password", "pin"]
        self.blocked_count = 0
    
    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        last_message = context.messages[-1] if context.messages else None
        if last_message and last_message.text:
            query = last_message.text.lower()
            if any(pattern in query for pattern in self.blocked_patterns):
                self.blocked_count += 1
                print(f"[🔒 PII Protection] ⚠️ Blocked request #{self.blocked_count} - PII detected!")
                # Override with warning message
                context.result = AgentResponse(
                    messages=[
                        ChatMessage(role=Role.ASSISTANT, text="‚õî Request blocked: PII detected. Please remove sensitive information.")
                    ]
                )
                return

        print("[🔒 PII Protection] ✅ No PII detected.")
        await next(context)
    
    def get_stats(self) -> dict:
        """Get middleware statistics."""
        return {"blocked_requests": self.blocked_count}


class RequestCounterMiddleware(FunctionMiddleware):
    """Function middleware that counts and logs all function calls."""
    
    def __init__(self, max_requests: int = 100):
        """Initialize with configurable rate limit."""
        self.max_requests = max_requests
        self.request_count = 0
        self.total_time = 0.0
    
    async def process(
        self,
        context: FunctionInvocationContext,
        next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        self.request_count += 1
        function_name = context.function.name

        if self.request_count > self.max_requests:
            print(f"[📊 Counter] ⚠️ Rate limit exceeded! ({self.request_count}/{self.max_requests})")
            # Skip the tool, but give the model an explanation it can relay to the user
            context.result = f"Request limit of {self.max_requests} tool calls exceeded."
            return

        print(f"[📊 Counter] Request #{self.request_count} - {function_name}")

        start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
        await next(context)
        duration = time.perf_counter() - start_time

        self.total_time += duration
        print(f"[📊 Counter] {function_name} completed in {duration:.3f}s")
    
    def get_stats(self) -> dict:
        """Get middleware statistics."""
        return {
            "total_requests": self.request_count,
            "total_time": f"{self.total_time:.3f}s",
            "avg_time": f"{self.total_time / max(self.request_count, 1):.3f}s"
        }

print("✅ Class-based middleware defined")

## Run the Credit Assessment Demo 🚀
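
In TEST 3, `PIIProtectionMiddleware` blocks the request before the model runs. No tool is called, so `RequestCounterMiddleware` never sees the request. This behavior follows from the nesting: agent middleware wraps the whole run, function middleware wraps each tool call inside that run, and a middleware that returns without awaiting `next` skips everything inside it. The following framework-free sketch illustrates that nesting. `Blocker`, `Counter`, `tool`, and `run` are hypothetical stand-ins. Queries are plain strings, and results are returned instead of being stored on a context object.

```python
import asyncio
from collections.abc import Awaitable, Callable

Handler = Callable[[str], Awaitable[str]]


class Blocker:
    """Outer, agent-level step: returns early instead of awaiting next."""

    def __init__(self, banned: str):
        self.banned = banned
        self.blocked = 0

    async def process(self, query: str, next: Handler) -> str:
        if self.banned in query.lower():
            self.blocked += 1
            return "blocked"  # short-circuit: inner steps never run
        return await next(query)


class Counter:
    """Inner, tool-level step: counts only the calls that reach it."""

    def __init__(self):
        self.calls = 0

    async def process(self, query: str, next: Handler) -> str:
        self.calls += 1
        return await next(query)


async def tool(query: str) -> str:
    return f"tool result for {query!r}"


blocker, counter = Blocker("ssn"), Counter()


async def run(query: str) -> str:
    # The agent-level step wraps the tool-level step, which wraps the tool.
    async def inner(q: str) -> str:
        return await counter.process(q, tool)

    return await blocker.process(query, inner)


async def demo() -> list[str]:
    return [await run("credit history for CUST-12345"),
            await run("check SSN 123-45-6789")]


results = asyncio.run(demo())  # in a notebook cell, use `await demo()` instead
print(results)
print(blocker.blocked, counter.calls)  # 1 1
```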

In [None]:
async def main():
    """Demonstrate class-based middleware with credit assessment."""
    print("=" * 60)
    print("🏛️ CREDIT ASSESSMENT MIDDLEWARE DEMO")
    print("=" * 60)
    print()
    
    # Create reusable middleware instances
    pii_middleware = PIIProtectionMiddleware(blocked_patterns=["ssn", "social security", "password"])
    counter_middleware = RequestCounterMiddleware(max_requests=10)
    
    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(
            project_endpoint=PROJECT_ENDPOINT,
            model_deployment_name=MODEL_DEPLOYMENT,
            credential=credential,
        ).as_agent(
            name="CreditAgent",
            instructions="""You are a credit assessment assistant. Help users check credit limits,
            review credit history, and understand credit decisions. Never request SSN or other PII.""",
            tools=[assess_credit_limit, get_credit_history],
            middleware=[pii_middleware, counter_middleware],
        ) as agent,
    ):
        print("✅ Agent created with class-based middleware:")
        print("   - PIIProtectionMiddleware (configurable blocked patterns)")
        print("   - RequestCounterMiddleware (tracks usage stats)")
        print()
        
        # Test 1: Normal credit check
        print("=" * 60)
        print("TEST 1: Normal Credit Assessment")
        print("=" * 60)
        query = "Assess credit limit for customer CUST-12345 requesting $25,000"
        print(f"👤 User: {query}")
        result = await agent.run(query)
        print(f"🤖 Agent: {result.text or 'No response'}")
        print()
        
        # Test 2: Get credit history
        print("=" * 60)
        print("TEST 2: Credit History Request")
        print("=" * 60)
        query = "Get credit history for customer CUST-12345"
        print(f"👤 User: {query}")
        result = await agent.run(query)
        print(f"🤖 Agent: {result.text or 'No response'}")
        print()
        
        # Test 3: PII test (should be blocked)
        print("=" * 60)
        print("TEST 3: PII Test (should be BLOCKED)")
        print("=" * 60)
        query = "Check credit for SSN 123-45-6789"
        print(f"👤 User: {query}")
        result = await agent.run(query)
        print(f"🤖 Agent: {result.text or 'No response'}")
        print()
        
        # Show middleware stats (benefit of class-based approach)
        print("=" * 60)
        print("MIDDLEWARE STATISTICS (stateful advantage)")
        print("=" * 60)
        print(f"📊 PII Middleware: {pii_middleware.get_stats()}")
        print(f"📊 Counter Middleware: {counter_middleware.get_stats()}")
        print()
        
        print("=" * 60)
        print("✅ DEMO COMPLETE")
        print("=" * 60)

await main()

## Key Takeaways 📚

### Class-Based Middleware Pattern

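```python
from collections.abc import Awaitable, Callable

from agent_framework import AgentMiddleware, AgentRunContext


class MyMiddleware(AgentMiddleware):  # or FunctionMiddleware with FunctionInvocationContext
    def __init__(self, config_param: str):
        self.config = config_param
        self.counter = 0  # Stateful!

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        self.counter += 1  # Track state
        # Pre-processing...
        await next(context)
        # Post-processing...

    def get_stats(self) -> dict:  # Custom methods
        return {"count": self.counter}
```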
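
Post-processing written after `await next(context)`, such as the duration bookkeeping in `RequestCounterMiddleware`, runs only if `next` returns normally. The following minimal sketch uses the same pre/post shape with `try`/`finally`, so the timing is recorded even when the wrapped call raises. It also uses `time.perf_counter()`, a monotonic clock that is better suited to measuring durations than `time.time()`. `TimingMiddleware`, `Ctx`, and the two tools are hypothetical, framework-free stand-ins.

```python
import asyncio
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


@dataclass
class Ctx:
    """Hypothetical stand-in for a function-invocation context."""
    name: str


class TimingMiddleware:
    def __init__(self):
        self.calls = 0
        self.failures = 0
        self.total_time = 0.0

    async def process(self, context: Ctx, next: Callable[[Ctx], Awaitable[None]]) -> None:
        self.calls += 1                  # pre-processing
        start = time.perf_counter()
        try:
            await next(context)
        except Exception:
            self.failures += 1
            raise                        # the caller still sees the error
        finally:
            self.total_time += time.perf_counter() - start  # always runs


async def ok_tool(ctx: Ctx) -> None:
    await asyncio.sleep(0.01)


async def failing_tool(ctx: Ctx) -> None:
    raise RuntimeError(f"{ctx.name} failed")


timing = TimingMiddleware()


async def demo() -> None:
    await timing.process(Ctx("get_credit_history"), ok_tool)
    try:
        await timing.process(Ctx("assess_credit_limit"), failing_tool)
    except RuntimeError as exc:
        print(f"caught: {exc}")  # caught: assess_credit_limit failed


asyncio.run(demo())  # in a notebook cell, use `await demo()` instead
print(timing.calls, timing.failures, timing.total_time > 0)  # 2 1 True
```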

### When to Use Class-Based Middleware

| Use Case | Function-Based | Class-Based |
|----------|----------------|-------------|
| Simple logging | ✅ | Overkill |
| Configurable behavior | ❌ | ✅ |
| Track statistics | ❌ | ✅ |
| Rate limiting | ❌ | ✅ |
| Caching | ❌ | ✅ |
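
`RequestCounterMiddleware` enforces a lifetime cap. After `max_requests` calls, every later call is refused for as long as that instance exists. Rate limiting usually means a cap per time window instead. The following minimal sketch shows a sliding-window limiter that a class-based middleware could hold as state. `SlidingWindowLimiter` is hypothetical, and its clock is injectable so the example is deterministic.

```python
import time
from collections import deque
from collections.abc import Callable


class SlidingWindowLimiter:
    """Allow at most max_calls within any window_seconds-long interval."""

    def __init__(self, max_calls: int, window_seconds: float,
                 clock: Callable[[], float] = time.monotonic):
        self.max_calls = max_calls
        self.window = window_seconds
        self.clock = clock
        self.timestamps: deque[float] = deque()

    def allow(self) -> bool:
        now = self.clock()
        # Drop timestamps that have slid out of the window.
        while self.timestamps and now - self.timestamps[0] >= self.window:
            self.timestamps.popleft()
        if len(self.timestamps) < self.max_calls:
            self.timestamps.append(now)
            return True
        return False


# A fake clock keeps the example deterministic.
fake_now = [0.0]
limiter = SlidingWindowLimiter(max_calls=2, window_seconds=60, clock=lambda: fake_now[0])

decisions = [limiter.allow(), limiter.allow(), limiter.allow()]  # all at t = 0 s
fake_now[0] = 61.0
decisions.append(limiter.allow())  # the earlier calls have expired
print(decisions)  # [True, True, False, True]
```

Inside `process`, a middleware built around this limiter would call `allow()` before `await next(context)` and return early whenever it gets `False`. In normal use, the default `time.monotonic` clock applies.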

### Key Advantages
- **State management**: Track counters, caches, timings
- **Configuration**: Constructor parameters for customization
- **Reusability**: Share instances across agents
- **Testability**: Dependencies come in through the constructor, so `process` can be unit tested directly with a fake context and a stub `next`
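
A middleware class is an ordinary object, so its `process` method can be tested without a model or Azure credentials. Pass it a fake context and a stub `next`. The following minimal sketch tests a simplified copy of the PII check. `FakeMessage`, `FakeContext`, `KeywordBlocker`, and `run_case` are hypothetical stand-ins. Real `AgentRunContext` objects are not used because how they are constructed may depend on the installed `agent_framework` version.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeMessage:
    text: str


@dataclass
class FakeContext:
    """Hypothetical stand-in exposing only what the middleware reads and writes."""
    messages: list[FakeMessage]
    result: Optional[str] = None


class KeywordBlocker:
    """Simplified copy of the PII check that stores a plain-string result."""

    def __init__(self, blocked_patterns: list[str]):
        self.blocked_patterns = blocked_patterns
        self.blocked_count = 0

    async def process(self, context: FakeContext, next) -> None:
        query = context.messages[-1].text.lower() if context.messages else ""
        if any(p in query for p in self.blocked_patterns):
            self.blocked_count += 1
            context.result = "blocked"
            return  # short-circuit: next is never awaited
        await next(context)


async def run_case(text: str) -> tuple[FakeContext, int, int]:
    """Run one message through a fresh middleware with a stub `next`."""
    middleware = KeywordBlocker(["ssn"])
    next_calls: list[FakeContext] = []

    async def stub_next(ctx: FakeContext) -> None:
        next_calls.append(ctx)
        ctx.result = "model answer"  # pretend the model replied

    ctx = FakeContext(messages=[FakeMessage(text)])
    await middleware.process(ctx, stub_next)
    return ctx, len(next_calls), middleware.blocked_count


# In a notebook cell, replace asyncio.run(...) with await run_case(...).
blocked_ctx, blocked_next_calls, blocked_count = asyncio.run(run_case("Check SSN 123-45-6789"))
clean_ctx, clean_next_calls, _ = asyncio.run(run_case("Credit history for CUST-12345"))

assert blocked_ctx.result == "blocked" and blocked_next_calls == 0 and blocked_count == 1
assert clean_ctx.result == "model answer" and clean_next_calls == 1
print("middleware tests passed")
```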