# Class-based Middlewares for HR Agents - LangChain 1.0

**Module:** Class-based Middleware Architecture

**What you'll learn:**
- 🏗️ Building reusable middleware classes
- 🔄 Stateful middleware patterns
- 🎯 Complex middleware with initialization
- 🔌 Composable middleware architecture
- 🏭 Production-ready middleware design

**HR Use Cases:**
- Multi-tenant HR systems
- Complex approval workflows
- Advanced auditing and compliance
- Configurable business rules
- Enterprise-grade security

**Time:** 2-3 hours

---

## Setup

In [None]:
!pip install --pre -U langchain langchain-openai langgraph
!pip install langgraph-checkpoint-sqlite

In [None]:
from google.colab import userdata
import os

os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

from langchain.agents import create_agent, AgentState
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from typing import Annotated, Optional, Dict, List, Any
from datetime import datetime, timedelta
from abc import ABC, abstractmethod
from enum import Enum
import json
import hashlib

print("✅ Setup complete!")

## Setup: HR Data

In [None]:
# Employee database
EMPLOYEES = {
    "101": {
        "name": "Priya Sharma",
        "department": "Engineering",
        "role": "Senior Developer",
        "salary": 120000,
        "level": "L4",
        "manager_id": "102",
        "permissions": ["read", "write"]
    },
    "102": {
        "name": "Rahul Verma",
        "department": "Engineering",
        "role": "Engineering Manager",
        "salary": 180000,
        "level": "L6",
        "manager_id": "103",
        "permissions": ["read", "write", "approve"]
    },
    "103": {
        "name": "Anjali Patel",
        "department": "HR",
        "role": "HR Director",
        "salary": 200000,
        "level": "L7",
        "manager_id": None,
        "permissions": ["read", "write", "approve", "admin"]
    },
    "104": {
        "name": "Arjun Reddy",
        "department": "Sales",
        "role": "Sales Team Lead",
        "salary": 150000,
        "level": "L5",
        "manager_id": "105",
        "permissions": ["read", "write"]
    },
    "105": {
        "name": "Sneha Gupta",
        "department": "Sales",
        "role": "Sales Director",
        "salary": 190000,
        "level": "L7",
        "manager_id": "103",
        "permissions": ["read", "write", "approve"]
    }
}

# HR Tools
@tool
def get_employee_info(employee_id: Annotated[str, "Employee ID"]) -> str:
    """Get employee information."""
    if employee_id in EMPLOYEES:
        emp = EMPLOYEES[employee_id]
        return f"{emp['name']} - {emp['department']} - {emp['role']} (Level {emp['level']})"
    return "Employee not found"

@tool
def check_salary(employee_id: Annotated[str, "Employee ID"]) -> str:
    """Check salary. SENSITIVE."""
    if employee_id in EMPLOYEES:
        return f"Salary: ₹{EMPLOYEES[employee_id]['salary']:,}"
    return "Not found"

@tool
def update_salary(employee_id: Annotated[str, "Employee ID"], new_salary: Annotated[int, "New salary"]) -> str:
    """Update salary. CRITICAL operation."""
    if employee_id in EMPLOYEES:
        old = EMPLOYEES[employee_id]['salary']
        EMPLOYEES[employee_id]['salary'] = new_salary
        return f"Updated: ₹{old:,} → ₹{new_salary:,}"
    return "Failed"

@tool
def terminate_employee(employee_id: Annotated[str, "Employee ID"], reason: Annotated[str, "Termination reason"]) -> str:
    """Terminate employee. CRITICAL operation."""
    if employee_id in EMPLOYEES:
        return f"TERMINATION NOTICE: {EMPLOYEES[employee_id]['name']} - Reason: {reason}"
    return "Failed"

print(f"✅ Loaded {len(EMPLOYEES)} employees")

---
# Part 1: Base Middleware Class

## Abstract Base Class Pattern

**Benefits:**
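- Enforces a consistent interface (`_pre_process` / `_post_process`)
- Enables polymorphism: the orchestrator treats every middleware the same way
- Provides shared functionality (enable/disable, call statistics, error handling)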
- Simplifies testing
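Here is a minimal, standalone sketch of the template-method idea behind `BaseMiddleware`. The public method does the shared bookkeeping and subclasses supply the variable step. Python refuses to instantiate a subclass that forgets an abstract method. `Hook`, `GreetingHook`, and `IncompleteHook` are illustrative names and are not part of this notebook.

```python
from abc import ABC, abstractmethod


class Hook(ABC):
    """Template method: `run` does shared bookkeeping, subclasses supply `_process`."""

    def __init__(self) -> None:
        self.calls = 0

    def run(self, state: dict) -> dict:
        # Shared logic lives in the base class...
        self.calls += 1
        # ...and the variable step is delegated to the subclass
        return self._process(state)

    @abstractmethod
    def _process(self, state: dict) -> dict:
        ...


class GreetingHook(Hook):
    def _process(self, state: dict) -> dict:
        return {"greeting": f"Hello, {state['user']}"}


class IncompleteHook(Hook):
    pass  # forgot to implement _process


hook = GreetingHook()
print(hook.run({"user": "Priya"}))  # {'greeting': 'Hello, Priya'}
print(hook.calls)                   # 1

try:
    IncompleteHook()
except TypeError as exc:
    print("Cannot instantiate:", type(exc).__name__)
```

The check happens at instantiation time. A half-finished middleware fails as soon as it is constructed, long before it is wired into an agent.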

---

## Lab 1.1: Define Base Middleware Class

In [None]:
class BaseMiddleware(ABC):
    """Abstract base class for all middlewares."""
    
    def __init__(self, name: Optional[str] = None, enabled: bool = True):
        self.name = name or self.__class__.__name__
        self.enabled = enabled
        self.stats = {
            "pre_calls": 0,
            "post_calls": 0,
            "errors": 0
        }
    
    def pre_hook(self, state: AgentState) -> dict:
        """Pre-model entry point: handles enable/stats/errors, then calls `_pre_process`."""
        if not self.enabled:
            return {}
        
        try:
            self.stats["pre_calls"] += 1
            return self._pre_process(state)
        except Exception as e:
            self.stats["errors"] += 1
            return self._handle_error("pre_hook", e, state)
    
    def post_hook(self, state: AgentState) -> dict:
        """Post-model entry point: handles enable/stats/errors, then calls `_post_process`."""
        if not self.enabled:
            return {}
        
        try:
            self.stats["post_calls"] += 1
            return self._post_process(state)
        except Exception as e:
            self.stats["errors"] += 1
            return self._handle_error("post_hook", e, state)
    
    @abstractmethod
    def _pre_process(self, state: AgentState) -> dict:
        """Implement pre-processing logic."""
        pass
    
    @abstractmethod
    def _post_process(self, state: AgentState) -> dict:
        """Implement post-processing logic."""
        pass
    
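    def _handle_error(self, hook_name: str, error: Exception, state: AgentState) -> dict:
        """Log the error and continue.

        Returning {} is fail-open: the request proceeds as if this middleware
        were absent. Security middlewares may want to override this to block
        the request instead.
        """
        print(f"❌ [{self.name}] Error in {hook_name}: {error}")
        return {}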
    
    def enable(self):
        """Enable this middleware."""
        self.enabled = True
    
    def disable(self):
        """Disable this middleware."""
        self.enabled = False
    
    def get_stats(self) -> dict:
        """Get middleware statistics."""
        return {
            "name": self.name,
            "enabled": self.enabled,
            **self.stats
        }
    
    def __repr__(self):
        status = "✅" if self.enabled else "❌"
        return f"{status} {self.name}"

print("✅ BaseMiddleware class defined!")
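`_handle_error` returns `{}`. When a middleware raises, the request simply continues as if that middleware were absent (fail-open). That is usually fine for caching or metrics, but a fail-closed policy is safer for authentication and authorization. The standalone sketch below assumes the same `messages` + `jump_to` update convention this notebook uses. `safe_call` and `broken_auth_check` are hypothetical helpers.

```python
from typing import Callable


def safe_call(hook: Callable[[dict], dict], state: dict, fail_closed: bool) -> dict:
    """Run a hook; on error either pass through (fail-open) or stop the run (fail-closed)."""
    try:
        return hook(state)
    except Exception as exc:
        if not fail_closed:
            # Fail-open: behave as if the middleware were absent
            return {}
        # Fail-closed: block the request with an explanatory message
        return {
            "messages": [("assistant", f"Request blocked: security check failed ({type(exc).__name__}).")],
            "jump_to": "__end__",
        }


def broken_auth_check(state: dict) -> dict:
    # Simulates a bug: "profiles" is missing from state, so this raises KeyError
    return {"user_profile": state["profiles"][state["current_user_id"]]}


state = {"current_user_id": "101"}
print(safe_call(broken_auth_check, state, fail_closed=False))            # {}
print(safe_call(broken_auth_check, state, fail_closed=True)["jump_to"])  # __end__
```

In this notebook's classes, the equivalent change would be to override `_handle_error` in the auth middlewares so it returns the blocking update instead of `{}`.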

---
# Part 2: Authentication & Authorization Middleware

## Multi-Level Security

In [None]:
class Permission(Enum):
    """Permission levels."""
    READ = "read"
    WRITE = "write"
    APPROVE = "approve"
    ADMIN = "admin"

class AuthenticationMiddleware(BaseMiddleware):
    """Verify user authentication and load profile."""
    
    def __init__(self):
        super().__init__("Authentication")
        self.session_cache = {}  # user_id -> session_data
    
    def _pre_process(self, state: AgentState) -> dict:
        """Authenticate user and load profile."""
        user_id = state.get("current_user_id")
        
        print(f"\n🔐 [Authentication] Checking user: {user_id}")
        
        # Check if authenticated
        if not user_id or user_id not in EMPLOYEES:
            print(f"   ❌ Authentication failed")
            return {
                "authenticated": False,
                "messages": [("assistant", "❌ Authentication required. Invalid user ID.")],
                "jump_to": "__end__"
            }
        
        # Load user profile
        user_profile = EMPLOYEES[user_id]
        
        # Create/update session
        if user_id not in self.session_cache:
            self.session_cache[user_id] = {
                "first_seen": datetime.now(),
                "request_count": 0
            }
        
        self.session_cache[user_id]["request_count"] += 1
        self.session_cache[user_id]["last_seen"] = datetime.now()
        
        print(f"   ✅ Authenticated: {user_profile['name']}")
        print(f"   📋 Role: {user_profile['role']}")
        print(f"   🎫 Permissions: {', '.join(user_profile['permissions'])}")
        print(f"   📊 Session requests: {self.session_cache[user_id]['request_count']}")
        
        return {
            "authenticated": True,
            "user_profile": user_profile,
            "user_permissions": user_profile["permissions"]
        }
    
    def _post_process(self, state: AgentState) -> dict:
        """No post-processing needed for authentication."""
        return {}

class AuthorizationMiddleware(BaseMiddleware):
    """Check permissions for requested operations."""
    
    def __init__(self, tool_permissions: Dict[str, List[str]]):
        super().__init__("Authorization")
        self.tool_permissions = tool_permissions
    
    def _pre_process(self, state: AgentState) -> dict:
        """Check authorization before processing."""
        # Only check if user is authenticated
        if not state.get("authenticated"):
            return {}
        
        messages = state.get("messages", [])
        if not messages:
            return {}
        
        user_permissions = state.get("user_permissions", [])
        content = messages[-1].content.lower()
        
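        # Check if requesting sensitive operations. Match both the tool name
        # ("update_salary") and its natural-language form ("update salary"),
        # because users rarely type the identifier verbatim.
        for tool_name, required_perms in self.tool_permissions.items():
            keywords = {tool_name.lower(), tool_name.lower().replace("_", " ")}
            if any(kw in content for kw in keywords):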
                print(f"\n🔒 [Authorization] Operation: {tool_name}")
                print(f"   Required: {', '.join(required_perms)}")
                print(f"   User has: {', '.join(user_permissions)}")
                
                # Check permissions
                if not all(perm in user_permissions for perm in required_perms):
                    missing = [p for p in required_perms if p not in user_permissions]
                    print(f"   ❌ Missing permissions: {', '.join(missing)}")
                    
                    return {
                        "authorized": False,
                        "messages": [(
                            "assistant",
                            f"❌ Access Denied: You don't have permission to perform '{tool_name}'.\n\n"
                            f"Required: {', '.join(required_perms)}\n"
                            f"You have: {', '.join(user_permissions)}"
                        )],
                        "jump_to": "__end__"
                    }
                
                print(f"   ✅ Authorized")
        
        return {"authorized": True}
    
    def _post_process(self, state: AgentState) -> dict:
        return {}

print("✅ Auth middlewares defined!")
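`AuthorizationMiddleware` infers the operation from the user's wording. That approach misses paraphrases ("give Priya a raise") and can be gamed. A sturdier option is to inspect the tools the model actually proposes, in the post-model hook, before they execute. The sketch below assumes the last AI message exposes `tool_calls` as a list of dicts with a `"name"` key, which is the shape LangChain uses. `check_tool_calls` is a hypothetical helper.

```python
from typing import Dict, List


def check_tool_calls(
    tool_calls: List[dict],
    user_permissions: List[str],
    tool_permissions: Dict[str, List[str]],
) -> List[str]:
    """Return the names of proposed tool calls the user is NOT allowed to run."""
    denied = []
    for call in tool_calls:
        # Tools missing from the map need no extra permission
        required = tool_permissions.get(call["name"], [])
        if not set(required) <= set(user_permissions):
            denied.append(call["name"])
    return denied


TOOL_PERMISSIONS = {
    "update_salary": ["admin"],
    "terminate_employee": ["admin"],
    "check_salary": ["read"],
}

# A tool call as a model might propose it for "give Priya a raise to 200000"
proposed = [{"name": "update_salary", "args": {"employee_id": "101", "new_salary": 200000}, "id": "call_1"}]

print(check_tool_calls(proposed, ["read", "write"], TOOL_PERMISSIONS))           # ['update_salary']
print(check_tool_calls(proposed, ["read", "write", "admin"], TOOL_PERMISSIONS))  # []
```

If the returned list is non-empty, a post-model hook could end the run (for example with the `jump_to` convention used above) instead of letting the tools execute.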

---
# Part 3: Audit & Compliance Middleware

## Enterprise-Grade Auditing

In [None]:
class AuditLevel(Enum):
    """Audit logging levels."""
    INFO = "INFO"
    WARNING = "WARNING"
    CRITICAL = "CRITICAL"

class ComplianceAuditMiddleware(BaseMiddleware):
    """Comprehensive audit trail for compliance."""
    
    def __init__(self, critical_operations: List[str]):
        super().__init__("ComplianceAudit")
        self.critical_operations = critical_operations
        self.audit_log = []
        self.current_request = {}
    
    def _pre_process(self, state: AgentState) -> dict:
        """Log request details."""
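        messages = state.get("messages", [])
        # The hook fires before every model call, including the follow-up call
        # after a tool runs; only log fresh user requests.
        if not messages or getattr(messages[-1], "type", None) != "human":
            return {}
        
        user_profile = state.get("user_profile", {})
        content = messages[-1].content
        
        # Determine audit level (match "update_salary" and "update salary")
        level = AuditLevel.INFO
        for critical_op in self.critical_operations:
            op = critical_op.lower()
            if op in content.lower() or op.replace("_", " ") in content.lower():
                level = AuditLevel.CRITICAL
                break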
        
        # Store for post-processing
        self.current_request = {
            "request_id": f"req_{len(self.audit_log) + 1}",
            "timestamp": datetime.now().isoformat(),
            "user_id": state.get("current_user_id"),
            "user_name": user_profile.get("name", "Unknown"),
            "department": user_profile.get("department", "Unknown"),
            "role": user_profile.get("role", "Unknown"),
            "request": content,
            "level": level.value,
            "ip_address": "192.168.1.100",  # Mock
            "session_id": state.get("session_id", "unknown")
        }
        
        level_icon = "🔴" if level == AuditLevel.CRITICAL else "🔵"
        print(f"\n{level_icon} [Audit-{level.value}] Request logged")
        print(f"   ID: {self.current_request['request_id']}")
        print(f"   User: {user_profile.get('name', 'Unknown')}")
        print(f"   Operation: {content[:60]}...")
        
        return {}
    
    def _post_process(self, state: AgentState) -> dict:
        """Log response and finalize audit entry."""
        if not self.current_request:
            return {}
        
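        messages = state.get("messages", [])
        # Wait for the final answer: an AI message with pending tool calls
        # means the agent loop will call the model again.
        if messages and getattr(messages[-1], "tool_calls", None):
            return {}
        response = messages[-1].content if messages else "No response"
        
        # Collect tool calls made since the latest user message only; the
        # checkpointer keeps earlier turns of the thread in `messages`.
        last_human = max(
            (i for i, m in enumerate(messages) if getattr(m, "type", None) == "human"),
            default=-1,
        )
        tools_used = []
        for msg in messages[last_human + 1:]:
            if getattr(msg, "tool_calls", None):
                tools_used.extend(tc.get("name") for tc in msg.tool_calls)
        
        # Complete audit entry
        self.current_request.update({
            "response": response[:200],
            "tools_used": tools_used,
            "completed_at": datetime.now().isoformat(),
            "success": state.get("jump_to") != "__end__"
        })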
        
        self.audit_log.append(self.current_request)
        
        print(f"   📋 Audit entry saved: {self.current_request['request_id']}")
        if tools_used:
            print(f"   🔧 Tools used: {', '.join(tools_used)}")
        
        self.current_request = {}
        return {}
    
    def export_audit_log(self, level: AuditLevel = None) -> List[dict]:
        """Export audit log, optionally filtered by level."""
        if level:
            return [entry for entry in self.audit_log if entry['level'] == level.value]
        return self.audit_log
    
    def generate_compliance_report(self) -> dict:
        """Generate compliance report."""
        critical_ops = [e for e in self.audit_log if e['level'] == AuditLevel.CRITICAL.value]
        
        # User activity
        user_activity = {}
        for entry in self.audit_log:
            user = entry['user_id']
            user_activity[user] = user_activity.get(user, 0) + 1
        
        return {
            "report_generated": datetime.now().isoformat(),
            "total_requests": len(self.audit_log),
            "critical_operations": len(critical_ops),
            "user_activity": user_activity,
            "top_users": sorted(user_activity.items(), key=lambda x: x[1], reverse=True)[:5]
        }

print("✅ Audit middleware defined!")
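`ComplianceAuditMiddleware` keeps `audit_log` in memory, so its entries disappear when the kernel restarts. One common remedy is an append-only JSON Lines file, with one JSON object per line. The minimal sketch below uses an assumed file path and hypothetical helper names. A real deployment would more likely send entries to a database or log pipeline.

```python
import json
import os
import tempfile
from typing import List, Optional


def append_audit_entry(path: str, entry: dict) -> None:
    """Append one audit entry as a single JSON line (earlier lines are never rewritten)."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def read_audit_entries(path: str, level: Optional[str] = None) -> List[dict]:
    """Load entries back, optionally filtered by audit level."""
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        entries = [json.loads(line) for line in f if line.strip()]
    return [e for e in entries if level is None or e["level"] == level]


log_path = os.path.join(tempfile.mkdtemp(), "hr_audit.jsonl")
append_audit_entry(log_path, {"request_id": "req_1", "user_id": "101", "level": "INFO"})
append_audit_entry(log_path, {"request_id": "req_2", "user_id": "103", "level": "CRITICAL"})

print(len(read_audit_entries(log_path)))                        # 2
print(read_audit_entries(log_path, "CRITICAL")[0]["user_id"])   # 103
```

Calling `append_audit_entry` from `_post_process`, right after `self.audit_log.append(...)`, would persist each request as it completes.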

---
# Part 4: Rate Limiting & Throttling Middleware

## Configurable Rate Limits

In [None]:
class RateLimitStrategy(Enum):
    """Rate limiting strategies."""
    FIXED_WINDOW = "fixed_window"
    SLIDING_WINDOW = "sliding_window"
    TOKEN_BUCKET = "token_bucket"

class AdvancedRateLimitMiddleware(BaseMiddleware):
    """Per-user rate limiting.

    Only a sliding window is implemented; `strategy` is recorded and printed
    but does not change the algorithm.
    """
    
    def __init__(
        self,
        max_requests: int = 10,
        window_seconds: int = 60,
        strategy: RateLimitStrategy = RateLimitStrategy.SLIDING_WINDOW,
        per_user: bool = True
    ):
        super().__init__("RateLimit")
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.strategy = strategy
        self.per_user = per_user
        self.request_history = {}  # user_id -> [timestamps]
        self.violations = []  # Track violations
    
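    def _pre_process(self, state: AgentState) -> dict:
        """Check rate limit once per user request (not once per model call)."""
        messages = state.get("messages", [])
        # Skip the follow-up model calls made after a tool returns
        if messages and getattr(messages[-1], "type", None) != "human":
            return {}
        
        key = (state.get("current_user_id") or "global") if self.per_user else "global"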
        now = datetime.now()
        
        # Initialize history
        if key not in self.request_history:
            self.request_history[key] = []
        
        # Clean old requests (sliding window)
        cutoff = now - timedelta(seconds=self.window_seconds)
        self.request_history[key] = [
            ts for ts in self.request_history[key] if ts > cutoff
        ]
        
        current_count = len(self.request_history[key])
        
        print(f"\n⏱️  [Rate Limit] {self.strategy.value}")
        print(f"   User: {key}")
        print(f"   Window: {self.window_seconds}s")
        print(f"   Requests: {current_count}/{self.max_requests}")
        
        # Check limit
        if current_count >= self.max_requests:
            # Log violation
            violation = {
                "timestamp": now.isoformat(),
                "user_id": key,
                "count": current_count,
                "limit": self.max_requests
            }
            self.violations.append(violation)
            
            # Calculate retry time
            oldest_request = min(self.request_history[key])
            retry_after = self.window_seconds - (now - oldest_request).total_seconds()
            
            print(f"   ❌ RATE LIMIT EXCEEDED")
            print(f"   🔒 Retry after: {retry_after:.0f}s")
            
            return {
                "rate_limited": True,
                "messages": [(
                    "assistant",
                    f"⏱️ Rate Limit Exceeded\n\n"
                    f"You have made {current_count} requests in the last {self.window_seconds} seconds.\n"
                    f"Maximum allowed: {self.max_requests}\n\n"
                    f"Please try again in {retry_after:.0f} seconds."
                )],
                "jump_to": "__end__"
            }
        
        # Allow request
        self.request_history[key].append(now)
        remaining = self.max_requests - (current_count + 1)
        
        print(f"   ✅ Request allowed")
        print(f"   📊 Remaining: {remaining}")
        
        return {"rate_limited": False}
    
    def _post_process(self, state: AgentState) -> dict:
        return {}
    
    def get_violations(self) -> List[dict]:
        """Get rate limit violations."""
        return self.violations
    
    def reset_user_limit(self, user_id: str):
        """Reset rate limit for a specific user."""
        if user_id in self.request_history:
            del self.request_history[user_id]
            print(f"✅ Rate limit reset for user: {user_id}")

print("✅ Rate limit middleware defined!")
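`RateLimitStrategy` lists `TOKEN_BUCKET`, but `AdvancedRateLimitMiddleware` only implements a sliding window. A token bucket allows short bursts of up to `capacity` requests while enforcing an average rate of `refill_rate` tokens per second. The minimal sketch below uses a hypothetical class name. Its clock is injectable, so the behavior can be tested without sleeping.

```python
import time
from typing import Callable, Dict, Tuple


class TokenBucketLimiter:
    """Per-key token bucket: each request spends one token; tokens refill continuously."""

    def __init__(self, capacity: int, refill_rate: float, clock: Callable[[], float] = time.monotonic):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens added per second
        self.clock = clock
        self.buckets: Dict[str, Tuple[float, float]] = {}  # key -> (tokens, last_refill_time)

    def allow(self, key: str) -> bool:
        now = self.clock()
        # A new key starts with a full bucket
        tokens, last = self.buckets.get(key, (float(self.capacity), now))
        # Refill for the elapsed time, never exceeding capacity
        tokens = min(self.capacity, tokens + (now - last) * self.refill_rate)
        if tokens >= 1:
            self.buckets[key] = (tokens - 1, now)
            return True
        self.buckets[key] = (tokens, now)
        return False


# Fake clock so the example is deterministic
fake_now = [0.0]
limiter = TokenBucketLimiter(capacity=3, refill_rate=0.5, clock=lambda: fake_now[0])

print([limiter.allow("104") for _ in range(4)])     # [True, True, True, False]
fake_now[0] = 2.0  # 2 s later: 2 * 0.5 = 1 token refilled
print(limiter.allow("104"), limiter.allow("104"))  # True False
```

Compared with the sliding window above, the bucket stores only two numbers per user instead of a list of timestamps.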

---
# Part 5: Caching Middleware

## Intelligent Response Caching

In [None]:
class CacheStrategy(Enum):
    """Caching strategies."""
    LRU = "lru"  # Least Recently Used
    LFU = "lfu"  # Least Frequently Used
    TTL = "ttl"  # Time To Live

class SmartCachingMiddleware(BaseMiddleware):
    """Response cache: entries always expire after `ttl_seconds`; `strategy` picks the eviction policy when full."""
    
    def __init__(
        self,
        max_size: int = 100,
        ttl_seconds: int = 300,
        strategy: CacheStrategy = CacheStrategy.TTL
    ):
        super().__init__("SmartCache")
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.strategy = strategy
        self.cache = {}  # query_hash -> cache_entry
        self.hits = 0
        self.misses = 0
    
    def _hash_query(self, query: str) -> str:
        """Generate hash for query."""
        return hashlib.md5(query.lower().strip().encode()).hexdigest()
    
    def _pre_process(self, state: AgentState) -> dict:
        """Check cache before LLM call."""
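        messages = state.get("messages", [])
        # Only look up fresh user questions, not follow-up calls after tool use
        if not messages or getattr(messages[-1], "type", None) != "human":
            return {}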
        
        query = messages[-1].content
        query_hash = self._hash_query(query)
        
        print(f"\n💾 [Cache] Checking cache...")
        print(f"   Strategy: {self.strategy.value}")
        print(f"   Query hash: {query_hash[:12]}...")
        
        # Check cache
        if query_hash in self.cache:
            entry = self.cache[query_hash]
            
            # Check TTL
            age = (datetime.now() - entry['timestamp']).total_seconds()
            if age < self.ttl_seconds:
                # Cache hit!
                self.hits += 1
                entry['hits'] += 1
                entry['last_accessed'] = datetime.now()
                
                hit_rate = (self.hits / (self.hits + self.misses)) * 100
                
                print(f"   ✅ CACHE HIT!")
                print(f"   ⚡ Saved LLM call")
                print(f"   📊 Hit rate: {hit_rate:.1f}%")
                print(f"   🕐 Cache age: {age:.0f}s")
                
                return {
                    "messages": [("assistant", entry['response'])],
                    "jump_to": "__end__",
                    "cache_hit": True
                }
            else:
                # Expired
                print(f"   ⏰ Cache expired (age: {age:.0f}s)")
                del self.cache[query_hash]
        
        # Cache miss
        self.misses += 1
        print(f"   ❌ Cache miss")
        print(f"   📊 Cache size: {len(self.cache)}/{self.max_size}")
        
        return {"_query_hash": query_hash, "cache_hit": False}
    
    def _post_process(self, state: AgentState) -> dict:
        """Store response in cache."""
        # Skip if cache hit
        if state.get("cache_hit"):
            return {}
        
        query_hash = state.get("_query_hash")
        if not query_hash:
            return {}
        
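        messages = state.get("messages", [])
        # Don't cache an intermediate AI message that only requests tool calls
        if not messages or getattr(messages[-1], "tool_calls", None):
            return {}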
        
        response = messages[-1].content
        
        # Check size limit
        if len(self.cache) >= self.max_size:
            self._evict_entry()
        
        # Store in cache
        self.cache[query_hash] = {
            "response": response,
            "timestamp": datetime.now(),
            "last_accessed": datetime.now(),
            "hits": 0
        }
        
        print(f"\n💾 [Cache] Response cached")
        print(f"   Hash: {query_hash[:12]}...")
        print(f"   Size: {len(self.cache)}/{self.max_size}")
        
        return {}
    
    def _evict_entry(self):
        """Evict entry based on strategy."""
        if self.strategy == CacheStrategy.LRU:
            # Remove least recently accessed
            oldest = min(self.cache.items(), key=lambda x: x[1]['last_accessed'])
            del self.cache[oldest[0]]
        elif self.strategy == CacheStrategy.LFU:
            # Remove least frequently used
            least_used = min(self.cache.items(), key=lambda x: x[1]['hits'])
            del self.cache[least_used[0]]
        else:  # TTL - remove oldest
            oldest = min(self.cache.items(), key=lambda x: x[1]['timestamp'])
            del self.cache[oldest[0]]
    
    def get_cache_stats(self) -> dict:
        """Get cache statistics."""
        total = self.hits + self.misses
        hit_rate = (self.hits / total * 100) if total > 0 else 0
        
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": f"{hit_rate:.1f}%",
            "cache_size": len(self.cache),
            "max_size": self.max_size,
            "strategy": self.strategy.value
        }
    
    def clear_cache(self):
        """Clear all cache entries."""
        self.cache.clear()
        print("✅ Cache cleared")

print("✅ Caching middleware defined!")
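`_evict_entry` scans every entry with `min(...)`, which costs O(n) per eviction. For LRU, `collections.OrderedDict` keeps entries in access order, so both lookup and eviction are O(1). The minimal sketch below omits TTL handling. `LRUCache` is an illustrative name and is not part of the notebook.

```python
from collections import OrderedDict
from typing import Optional


class LRUCache:
    """Least-recently-used cache built on OrderedDict's ordering."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.data: "OrderedDict[str, str]" = OrderedDict()

    def get(self, key: str) -> Optional[str]:
        if key not in self.data:
            return None
        self.data.move_to_end(key)  # mark as most recently used
        return self.data[key]

    def put(self, key: str, value: str) -> None:
        if key in self.data:
            self.data.move_to_end(key)
        self.data[key] = value
        if len(self.data) > self.max_size:
            self.data.popitem(last=False)  # evict the least recently used entry


lru = LRUCache(max_size=2)
lru.put("q1", "Priya Sharma - Engineering")
lru.put("q2", "Rahul Verma - Engineering")
lru.get("q1")                        # q1 becomes most recently used
lru.put("q3", "Anjali Patel - HR")   # evicts q2, not q1
print(list(lru.data))                # ['q1', 'q3']
```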

---
# Part 6: Middleware Orchestrator

## Composing Multiple Middlewares

In [None]:
class MiddlewareOrchestrator:
    """Orchestrate multiple middlewares."""
    
    def __init__(self, middlewares: List[BaseMiddleware]):
        self.middlewares = middlewares
    
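    def pre_hook(self, state: AgentState) -> dict:
        """Run all pre-hooks in order; stop at the first one that requests a jump."""
        state = dict(state)  # shallow copy so the graph's input state isn't mutated
        updates = {}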
        
        for middleware in self.middlewares:
            if not middleware.enabled:
                continue
            
            result = middleware.pre_hook(state)
            
            # Merge updates
            updates.update(result)
            
            # Update state for next middleware
            state.update(result)
            
            # Early exit if jump requested
            if result.get("jump_to"):
                return updates
        
        return updates
    
    def post_hook(self, state: AgentState) -> dict:
        """Run all post-hooks in order."""
        state = dict(state)  # shallow copy so the graph's input state isn't mutated
        updates = {}
        
        for middleware in self.middlewares:
            if not middleware.enabled:
                continue
            
            result = middleware.post_hook(state)
            updates.update(result)
            state.update(result)
        
        return updates
    
    def get_all_stats(self) -> List[dict]:
        """Get stats from all middlewares."""
        return [mw.get_stats() for mw in self.middlewares]
    
    def enable_all(self):
        """Enable all middlewares."""
        for mw in self.middlewares:
            mw.enable()
    
    def disable_all(self):
        """Disable all middlewares."""
        for mw in self.middlewares:
            mw.disable()
    
    def get_middleware(self, name: str) -> Optional[BaseMiddleware]:
        """Get middleware by name."""
        for mw in self.middlewares:
            if mw.name == name:
                return mw
        return None
    
    def __repr__(self):
        return f"MiddlewareOrchestrator({len(self.middlewares)} middlewares)"

print("✅ MiddlewareOrchestrator defined!")
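The orchestrator runs hooks in list order and lets each hook see earlier updates. It stops as soon as one hook returns `jump_to`. The standalone sketch below uses plain functions instead of middleware classes, and all names in it are hypothetical. It shows why order matters: if a cache hook runs before authentication, a cached answer is returned without ever checking who is asking.

```python
from typing import Callable, Dict, List

HookFn = Callable[[dict], dict]


def run_chain(hooks: List[HookFn], state: dict) -> dict:
    """Run hooks in order on a copy of state; stop at the first jump_to."""
    state = dict(state)
    updates: Dict = {}
    for hook in hooks:
        result = hook(state)
        updates.update(result)
        state.update(result)  # later hooks see earlier updates
        if result.get("jump_to"):
            break
    return updates


ANSWER_CACHE = {"Tell me about employee 102": "Rahul Verma - Engineering Manager"}


def auth_hook(state: dict) -> dict:
    if state.get("user_id") not in {"101", "103"}:
        return {"authenticated": False, "jump_to": "__end__"}
    return {"authenticated": True}


def cache_hook(state: dict) -> dict:
    if state["query"] in ANSWER_CACHE:
        return {"answer": ANSWER_CACHE[state["query"]], "jump_to": "__end__"}
    return {}


intruder = {"user_id": "999", "query": "Tell me about employee 102"}
print(run_chain([cache_hook, auth_hook], intruder).get("answer"))  # leaks the cached answer
print(run_chain([auth_hook, cache_hook], intruder).get("answer"))  # None: blocked first
```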

---
# Part 7: Complete Production System

Let's put it all together!

In [None]:
# Create all middlewares
auth_mw = AuthenticationMiddleware()

authz_mw = AuthorizationMiddleware({
    "update_salary": ["admin"],
    "terminate_employee": ["admin"],
    "check_salary": ["read"]
})

audit_mw = ComplianceAuditMiddleware(
    critical_operations=["update_salary", "terminate_employee"]
)

rate_limit_mw = AdvancedRateLimitMiddleware(
    max_requests=5,
    window_seconds=60,
    strategy=RateLimitStrategy.SLIDING_WINDOW
)

cache_mw = SmartCachingMiddleware(
    max_size=50,
    ttl_seconds=300,
    strategy=CacheStrategy.TTL
)

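# Create orchestrator. Order matters: the first middleware to return
# `jump_to` ends the run, so security checks must come before the cache.
# Otherwise a cached answer would be served without authenticating,
# authorizing, or rate limiting the caller.
orchestrator = MiddlewareOrchestrator([
    auth_mw,         # 1. Authenticate
    authz_mw,        # 2. Authorize
    rate_limit_mw,   # 3. Rate limit
    audit_mw,        # 4. Audit
    cache_mw         # 5. Check cache last
])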

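# Custom state. AgentState is a TypedDict, and TypedDict fields cannot have
# default values (doing so raises TypeError), so optional keys use NotRequired.
from typing_extensions import NotRequired

class ProductionHRState(AgentState):
    current_user_id: NotRequired[str]
    session_id: NotRequired[str]
    authenticated: NotRequired[bool]
    authorized: NotRequired[bool]
    rate_limited: NotRequired[bool]
    cache_hit: NotRequired[bool]
    _query_hash: NotRequired[str]  # must be declared so the cache's pre-hook value reaches its post-hook
    user_profile: NotRequired[dict]
    user_permissions: NotRequired[list]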

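# Create production agent.
# This cell uses the hook-style arguments (`prompt`, `pre_model_hook`,
# `post_model_hook`) of the pre-release build installed above. The released
# LangChain 1.0 `create_agent` takes `system_prompt=` and a `middleware=[...]`
# list of `AgentMiddleware` subclasses instead; adapt if you use a stable release.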
production_hr_agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[get_employee_info, check_salary, update_salary, terminate_employee],
    pre_model_hook=orchestrator.pre_hook,
    post_model_hook=orchestrator.post_hook,
    state_schema=ProductionHRState,
    checkpointer=InMemorySaver(),
    prompt="""You are an enterprise HR assistant with comprehensive security.
    
    All requests are:
    - Authenticated and authorized
    - Rate limited
    - Fully audited
    - Cached when appropriate
    
    Be professional and security-conscious."""
)

print("✅ Production HR Agent created!")
print("\nMiddleware Stack:")
for i, mw in enumerate(orchestrator.middlewares, 1):
    print(f"  {i}. {mw}")

## Test Complete System

In [None]:
config = {"configurable": {"thread_id": "prod_session_1"}}

print("=" * 70)
print("PRODUCTION HR SYSTEM - COMPLETE TEST")
print("=" * 70)

# Test 1: Basic query (authenticated user)
print("\n" + "="*70)
print("Test 1: Basic Query (Regular User)")
print("="*70)

result = production_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Tell me about employee 102"}],
    "current_user_id": "101",  # Priya (regular user)
    "session_id": "session_001"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content[:100]}...")

# Test 2: Cached query (same question)
print("\n" + "="*70)
print("Test 2: Same Query (Should Hit Cache)")
print("="*70)

result = production_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Tell me about employee 102"}],
    "current_user_id": "101",
    "session_id": "session_001"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content[:100]}...")

# Test 3: Unauthorized operation
print("\n" + "="*70)
print("Test 3: Unauthorized Salary Update (Regular User)")
print("="*70)

result = production_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Update salary for employee 101 to 200000"}],
    "current_user_id": "101",  # Priya (no admin permission)
    "session_id": "session_001"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content}")

# Test 4: Authorized critical operation
print("\n" + "="*70)
print("Test 4: Authorized Salary Update (HR Director)")
print("="*70)

result = production_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Update salary for employee 101 to 150000"}],
    "current_user_id": "103",  # Anjali (HR Director with admin)
    "session_id": "session_002"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content}")

# Test 5: Rate limiting
print("\n" + "="*70)
print("Test 5: Rate Limiting (6 requests from same user)")
print("="*70)

for i in range(6):
    print(f"\nRequest {i+1}/6:")
    result = production_hr_agent.invoke({
        "messages": [{"role": "user", "content": f"Show info for employee 10{i%3 + 1}"}],
        "current_user_id": "104",  # Arjun
        "session_id": "session_003"
    }, config)
    
    if "Rate Limit" in result['messages'][-1].content:
        print(f"❌ Rate limited!")
        break
    else:
        print(f"✅ Request successful")

print("\n" + "="*70)
print("SYSTEM STATISTICS")
print("="*70)

# Middleware stats
print("\n📊 Middleware Statistics:")
for stats in orchestrator.get_all_stats():
    print(f"\n{stats['name']}:")
    print(f"  Enabled: {stats['enabled']}")
    print(f"  Pre-calls: {stats['pre_calls']}")
    print(f"  Post-calls: {stats['post_calls']}")
    print(f"  Errors: {stats['errors']}")

# Cache stats
print("\n💾 Cache Performance:")
cache_stats = cache_mw.get_cache_stats()
for key, value in cache_stats.items():
    print(f"  {key}: {value}")

# Rate limit violations
print("\n⚠️  Rate Limit Violations:")
violations = rate_limit_mw.get_violations()
print(f"  Total: {len(violations)}")
for v in violations:
    print(f"  • User {v['user_id']}: {v['count']}/{v['limit']} at {v['timestamp']}")

# Compliance report
print("\n📋 Compliance Report:")
report = audit_mw.generate_compliance_report()
print(json.dumps(report, indent=2))

# Critical operations audit
print("\n🔴 Critical Operations Log:")
critical_ops = audit_mw.export_audit_log(AuditLevel.CRITICAL)
for op in critical_ops:
    print(f"\n  {op['request_id']}:")
    print(f"    User: {op['user_name']} ({op['user_id']})")
    print(f"    Operation: {op['request'][:60]}...")
    print(f"    Tools: {', '.join(op.get('tools_used', []))}")
    print(f"    Time: {op['timestamp']}")

print("\n✅ Complete production system test finished!")

---
# Summary

## Class-based Middleware Benefits

✅ **Reusability** - Write once, use everywhere  
✅ **State Management** - Maintain internal state  
✅ **Configuration** - Initialize with parameters  
✅ **Testability** - Easy to unit test  
✅ **Composability** - Chain multiple middlewares  
✅ **Maintainability** - Organized, modular code  

## Architecture Patterns

### 1. Base Class Pattern
```python
class BaseMiddleware(ABC):
    @abstractmethod
    def _pre_process(self, state) -> dict: pass
    
    @abstractmethod
    def _post_process(self, state) -> dict: pass
```

### 2. Orchestrator Pattern
```python
orchestrator = MiddlewareOrchestrator([
    AuthenticationMiddleware(),
    ComplianceAuditMiddleware(critical_operations=["update_salary"]),
    SmartCachingMiddleware()
])
```

### 3. Configuration Pattern
```python
middleware = AdvancedRateLimitMiddleware(
    max_requests=10,
    window_seconds=60,
    strategy=RateLimitStrategy.SLIDING_WINDOW
)
```

## Production Checklist

✅ **Security Middlewares:**
- Authentication (verify identity)
- Authorization (check permissions)
- Input validation
- Output filtering

✅ **Performance Middlewares:**
- Caching (reduce LLM calls)
- Rate limiting (prevent abuse)
- Request throttling

✅ **Compliance Middlewares:**
- Audit logging
- Data retention
- Compliance reporting

✅ **Monitoring Middlewares:**
- Performance tracking
- Error logging
- Usage metrics
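The checklist mentions output filtering, but the notebook does not implement it. The minimal sketch below shows one form of it: rupee amounts in a response are masked unless the user holds a given permission. The regex, the permission name `"admin"`, and `filter_output` are assumptions for illustration. A real filter would target whatever data is actually considered sensitive.

```python
import re
from typing import List

# Matches amounts like "₹120,000", the format produced by check_salary
AMOUNT_PATTERN = re.compile(r"₹[\d,]+")


def filter_output(text: str, user_permissions: List[str], required: str = "admin") -> str:
    """Mask currency amounts unless the user holds the required permission."""
    if required in user_permissions:
        return text
    return AMOUNT_PATTERN.sub("₹[REDACTED]", text)


print(filter_output("Salary: ₹120,000", ["read", "write"]))           # Salary: ₹[REDACTED]
print(filter_output("Salary: ₹120,000", ["read", "write", "admin"]))  # Salary: ₹120,000
```

A post-model hook could apply this to the final AI message before it is returned to the user.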

## Best Practices

1. **Single Responsibility** - Each middleware does one thing well
2. **Error Handling** - Degrade gracefully, but fail closed in security middlewares
3. **Logging** - Structured logs for debugging
4. **Testing** - Unit test each middleware
5. **Documentation** - Clear docstrings
6. **Configuration** - Externalize settings
7. **Monitoring** - Track middleware performance
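One way to externalize settings is to read middleware parameters from environment variables, with defaults kept in code. Limits can then change per deployment without code edits. The minimal sketch below uses hypothetical variable names (`HR_RATE_LIMIT_MAX`, etc.) and a hypothetical `MiddlewareConfig` class.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class MiddlewareConfig:
    rate_limit_max: int = 10
    rate_limit_window_s: int = 60
    cache_ttl_s: int = 300

    @classmethod
    def from_env(cls, env: Mapping[str, str] = os.environ) -> "MiddlewareConfig":
        """Build a config, falling back to the class defaults for unset variables."""
        return cls(
            rate_limit_max=int(env.get("HR_RATE_LIMIT_MAX", cls.rate_limit_max)),
            rate_limit_window_s=int(env.get("HR_RATE_LIMIT_WINDOW_S", cls.rate_limit_window_s)),
            cache_ttl_s=int(env.get("HR_CACHE_TTL_S", cls.cache_ttl_s)),
        )


mw_config = MiddlewareConfig.from_env({"HR_RATE_LIMIT_MAX": "5"})
print(mw_config)  # MiddlewareConfig(rate_limit_max=5, rate_limit_window_s=60, cache_ttl_s=300)
```

The values would then feed the constructors, e.g. `AdvancedRateLimitMiddleware(max_requests=mw_config.rate_limit_max, window_seconds=mw_config.rate_limit_window_s)`.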

---

**Congratulations!** You now have production-ready class-based middleware patterns for enterprise HR systems! 🎉