# Decorator-based Middleware for HR Agents - LangChain 1.0

**Module:** Decorator-based Middleware Patterns

**What you'll learn:**
- 🎯 `@before_agent` - Before agent starts (once per invocation)
- 🎯 `@before_model` - Before each model call  
- 🎯 `@after_model` - After each model response
- 🎯 `@after_agent` - After agent completes (once per invocation)
- 🔄 `@wrap_model_call` - Around each model call
- 🔧 `@wrap_tool_call` - Around each tool call
- 💬 `@dynamic_prompt` - Generate dynamic system prompts

**HR Use Cases:**
- Session initialization and cleanup
- Per-call authentication checks
- Tool execution monitoring
- Dynamic context injection
- Performance tracking

**Time:** 2-3 hours

---

## Setup: Install Dependencies

In [None]:
!pip install --pre -U langchain langchain-openai langgraph
!pip install langgraph-checkpoint-sqlite

In [None]:
# Imports
from google.colab import userdata
import os

os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

from langchain.agents import create_agent, AgentState
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from typing import Annotated, Callable, Any
from datetime import datetime
from functools import wraps
import time
import json

print("✅ Setup complete!")

## Setup: HR Data and Tools

In [None]:
# Employee database
EMPLOYEES = {
    "101": {"name": "Priya Sharma", "department": "Engineering", "role": "Senior Developer", "salary": 120000},
    "102": {"name": "Rahul Verma", "department": "Engineering", "role": "Manager", "salary": 180000},
    "103": {"name": "Anjali Patel", "department": "HR", "role": "HR Director", "salary": 200000},
    "104": {"name": "Arjun Reddy", "department": "Sales", "role": "Team Lead", "salary": 150000},
    "105": {"name": "Sneha Gupta", "department": "Marketing", "role": "Specialist", "salary": 110000}
}

@tool
def get_employee_info(employee_id: Annotated[str, "Employee ID"]) -> str:
    """Get employee information."""
    if employee_id in EMPLOYEES:
        emp = EMPLOYEES[employee_id]
        return f"{emp['name']} - {emp['department']} - {emp['role']}"
    return f"Employee {employee_id} not found"

@tool
def check_salary(employee_id: Annotated[str, "Employee ID"]) -> str:
    """Check employee salary. SENSITIVE."""
    if employee_id in EMPLOYEES:
        return f"Salary: ₹{EMPLOYEES[employee_id]['salary']:,}"
    return "Not found"

@tool
def update_employee_record(employee_id: Annotated[str, "Employee ID"], field: Annotated[str, "Field"], value: Annotated[str, "Value"]) -> str:
    """Update employee record."""
    if employee_id in EMPLOYEES and field in EMPLOYEES[employee_id]:
        old_value = EMPLOYEES[employee_id][field]
        EMPLOYEES[employee_id][field] = value
        return f"Updated {field}: {old_value} → {value}"
    return "Update failed"

print(f"✅ Loaded {len(EMPLOYEES)} employees")

---
# Part 1: Session-Level Decorators

## @before_agent and @after_agent

**When they run:**
- `@before_agent`: Once at the START of agent invocation
- `@after_agent`: Once at the END of agent invocation

**Use cases:**
- Session initialization/cleanup
- Authentication
- Logging start/end of interactions
- Performance measurement
- Resource allocation/release
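
The once-per-invocation behaviour can be illustrated without any framework. The sketch below is a simplified stand-in, not LangChain's implementation: `run_invocation`, `start_session`, and `end_session` are hypothetical names, and the model loop is simulated with a counter.

```python
from datetime import datetime
from typing import Callable

Hook = Callable[[dict], dict]


def run_invocation(state: dict, before: list[Hook], after: list[Hook],
                   model_steps: int = 3) -> dict:
    """Hypothetical runner: session hooks fire once, however many model steps run."""
    state = dict(state)
    for hook in before:                 # start of invocation
        state.update(hook(state))
    for _ in range(model_steps):        # simulated model/tool loop
        state["model_calls"] = state.get("model_calls", 0) + 1
    for hook in after:                  # end of invocation
        state.update(hook(state))
    return state


def start_session(state: dict) -> dict:
    """Record when the session began and count how often this hook ran."""
    return {
        "session_start": datetime.now().isoformat(),
        "before_runs": state.get("before_runs", 0) + 1,
    }


def end_session(state: dict) -> dict:
    """Count how often the end-of-invocation hook ran."""
    return {"after_runs": state.get("after_runs", 0) + 1}


final = run_invocation({"current_user_id": "101"}, [start_session], [end_session])
print(final["before_runs"], final["model_calls"], final["after_runs"])
```

Each session hook runs exactly once here even though the simulated loop performs three model steps, which is why setup and cleanup belong at this level.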

---

## Lab 1.1: @before_agent - Session Initialization

In [None]:
# Decorator implementation
class before_agent:
    """Decorator for before-agent hooks."""
    
    def __init__(self, func: Callable):
        self.func = func
        wraps(func)(self)
    
    def __call__(self, state: AgentState) -> dict:
        print(f"\n🚀 [@before_agent] {self.func.__name__}")
        return self.func(state)

# Example: Session initialization
@before_agent
def initialize_hr_session(state: AgentState) -> dict:
    """Initialize HR consultation session."""
    session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    user_id = state.get("current_user_id", "unknown")
    
    print(f"   📋 Session ID: {session_id}")
    print(f"   👤 User: {user_id}")
    print(f"   🕐 Started: {datetime.now().isoformat()}")
    
    # Initialize session data
    return {
        "session_id": session_id,
        "session_start": datetime.now().isoformat(),
        "interaction_count": 0
    }

# Example: Authentication check
@before_agent  
def verify_user_authentication(state: AgentState) -> dict:
    """Verify user is authenticated."""
    user_id = state.get("current_user_id")
    
    if not user_id or user_id not in EMPLOYEES:
        print(f"   ❌ Authentication failed for: {user_id}")
        return {
            "authenticated": False,
            "messages": [("assistant", "❌ Authentication required. Please log in.")],
            "jump_to": "__end__"
        }
    
    print(f"   ✅ Authenticated: {EMPLOYEES[user_id]['name']}")
    return {"authenticated": True}

print("✅ @before_agent decorators defined!")

## Lab 1.2: @after_agent - Session Cleanup

In [None]:
class after_agent:
    """Decorator for after-agent hooks."""
    
    def __init__(self, func: Callable):
        self.func = func
        wraps(func)(self)
    
    def __call__(self, state: AgentState) -> dict:
        print(f"\n🏁 [@after_agent] {self.func.__name__}")
        return self.func(state)

# Example: Session summary
@after_agent
def log_session_summary(state: AgentState) -> dict:
    """Log session summary after completion."""
    session_id = state.get("session_id", "unknown")
    start_time = state.get("session_start", "unknown")
    
    if start_time != "unknown":
        duration = (datetime.now() - datetime.fromisoformat(start_time)).total_seconds()
        print(f"   ⏱️  Duration: {duration:.2f}s")
    
    print(f"   📋 Session: {session_id}")
    print(f"   💬 Messages: {len(state.get('messages', []))}")
    print(f"   🕐 Ended: {datetime.now().isoformat()}")
    
    return {}

# Example: Cleanup resources
@after_agent
def cleanup_resources(state: AgentState) -> dict:
    """Release resources after session."""
    print(f"   🧹 Cleaning up session resources...")
    print(f"   💾 Saving session data...")
    print(f"   ✅ Cleanup complete")
    return {}

print("✅ @after_agent decorators defined!")

## Lab 1.3: Combining Session Decorators

In [None]:
# Combine session-level hooks
def combined_before_agent_hook(state: AgentState) -> dict:
    """Run all before-agent decorators."""
    # Run authentication first; stop early if it fails
    updates = verify_user_authentication(state)
    if updates.get("jump_to"):
        return updates
    
    # Merge session fields so the "authenticated" flag is not discarded
    updates.update(initialize_hr_session(state))
    return updates

def combined_after_agent_hook(state: AgentState) -> dict:
    """Run all after-agent decorators."""
    log_session_summary(state)
    cleanup_resources(state)
    return {}

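# For the demo, these hooks are invoked from the pre/post model hooks.
# Declare every key the hooks return so the updates are persisted in agent state.
class SessionAgentState(AgentState):
    current_user_id: str = ""
    session_id: str = ""
    session_start: str = ""
    authenticated: bool = False
    interaction_count: int = 0
    current_model_call: int = 0
    model_call_start: float = 0.0
    dynamic_context: str = ""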

print("✅ Session hooks combined!")

---
# Part 2: Call-Level Decorators

## @before_model and @after_model

**When they run:**
- `@before_model`: Before EACH model call (can be multiple per agent invocation)
- `@after_model`: After EACH model response

**Use cases:**
- Per-call logging
- Token tracking
- Response validation
- Cost calculation
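
To see why these hooks can fire several times in one invocation, consider a toy agent loop. This is a minimal sketch under stated assumptions: `fake_model` and `run_agent` are hypothetical stand-ins for the LLM and the agent runtime, and messages are plain strings rather than LangChain message objects.

```python
from typing import Callable


def fake_model(messages: list[str]) -> str:
    """Stand-in for an LLM: requests a tool once, then gives a final answer."""
    if not any(m.startswith("tool:") for m in messages):
        return "call_tool:get_employee_info"
    return "final: Priya Sharma - Engineering"


def run_agent(question: str,
              before_model: Callable[[list[str]], None],
              after_model: Callable[[str], None]) -> list[str]:
    """Hypothetical agent loop: the hooks surround every model call."""
    messages = [f"user: {question}"]
    while True:
        before_model(messages)
        reply = fake_model(messages)
        after_model(reply)
        messages.append(reply)
        if reply.startswith("final:"):
            return messages
        # Simulated tool result, which triggers another model call
        messages.append("tool: Priya Sharma - Engineering - Senior Developer")


calls = {"before": 0, "after": 0}


def count_before(messages: list[str]) -> None:
    calls["before"] += 1


def count_after(reply: str) -> None:
    calls["after"] += 1


history = run_agent("Tell me about my role", count_before, count_after)
print(calls)
```

A single user question produces two model calls in this sketch (one to request the tool, one to answer), so each hook runs twice. Anything expensive placed in these hooks is paid for on every iteration.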

---

## Lab 2.1: @before_model - Per-Call Logging

In [None]:
# Call counter
model_call_stats = {
    "total_calls": 0,
    "call_history": []
}

class before_model:
    """Decorator for before-model hooks."""
    
    def __init__(self, func: Callable):
        self.func = func
        wraps(func)(self)
    
    def __call__(self, state: AgentState) -> dict:
        print(f"\n🤖 [@before_model] {self.func.__name__}")
        return self.func(state)

@before_model
def log_model_call(state: AgentState) -> dict:
    """Log each model call."""
    model_call_stats["total_calls"] += 1
    call_num = model_call_stats["total_calls"]
    
    messages = state.get("messages", [])
    last_msg = messages[-1].content if messages else "No message"
    
    call_info = {
        "call_number": call_num,
        "timestamp": datetime.now().isoformat(),
        "user_message": last_msg[:50] + "..." if len(last_msg) > 50 else last_msg
    }
    
    model_call_stats["call_history"].append(call_info)
    
    print(f"   📞 Model Call #{call_num}")
    print(f"   💬 Query: {call_info['user_message']}")
    
    return {"current_model_call": call_num}

@before_model
def check_rate_limit_per_call(state: AgentState) -> dict:
    """Check if rate limit exceeded."""
    # Simple demo: limit 5 model calls per session
    if model_call_stats["total_calls"] > 5:
        print(f"   ⚠️  Rate limit warning: {model_call_stats['total_calls']} calls")
    
    return {}

print("✅ @before_model decorators defined!")

## Lab 2.2: @after_model - Response Tracking

In [None]:
response_stats = {
    "total_tokens": 0,
    "total_cost": 0.0
}

class after_model:
    """Decorator for after-model hooks."""
    
    def __init__(self, func: Callable):
        self.func = func
        wraps(func)(self)
    
    def __call__(self, state: AgentState) -> dict:
        print(f"\n✅ [@after_model] {self.func.__name__}")
        return self.func(state)

@after_model
def track_token_usage(state: AgentState) -> dict:
    """Track token usage per model call."""
    messages = state.get("messages", [])
    
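    # Rough estimate (~1.3 tokens per word); skip non-text content such as block lists
    text_msgs = [m for m in messages if isinstance(getattr(m, 'content', None), str)]
    tokens = int(sum(len(m.content.split()) * 1.3 for m in text_msgs))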
    
    cost = (tokens / 1000) * 0.002  # $0.002 per 1K tokens
    
    response_stats["total_tokens"] += tokens
    response_stats["total_cost"] += cost
    
    call_num = state.get("current_model_call", "?")
    print(f"   💰 Call #{call_num}: ~{tokens} tokens (${cost:.4f})")
    print(f"   📊 Session total: ~{response_stats['total_tokens']} tokens (${response_stats['total_cost']:.4f})")
    
    return {}

@after_model
def validate_response_quality(state: AgentState) -> dict:
    """Validate model response."""
    messages = state.get("messages", [])
    if messages:
        response = messages[-1].content
        
        # Simple quality checks
        if len(response) < 10:
            print(f"   ⚠️  Warning: Short response ({len(response)} chars)")
        
        if "error" in response.lower() or "sorry" in response.lower():
            print(f"   ⚠️  Warning: Possible error response detected")
    
    return {}

print("✅ @after_model decorators defined!")

---
# Part 3: Wrap-Style Decorators

## @wrap_model_call - Intercept Model Calls

**What it does:**
- Wraps the entire model call
- Can run code before AND after
- Can modify request/response
- Can skip the call entirely

**Use cases:**
- Performance timing
- Request/response transformation
- Error handling
- Caching
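
The wrap pattern itself is plain function composition. The following sketch assumes a model call can be represented as a function from a prompt string to a response string; `with_cache`, `with_timing`, and `fake_model` are illustrative names, not LangChain APIs.

```python
import time
from typing import Callable

ModelCall = Callable[[str], str]


def with_cache(handler: ModelCall) -> ModelCall:
    """Return cached responses and skip the wrapped call on a repeat prompt."""
    cache: dict[str, str] = {}

    def wrapped(prompt: str) -> str:
        if prompt in cache:
            return cache[prompt]        # the model is never called
        response = handler(prompt)
        cache[prompt] = response
        return response

    return wrapped


def with_timing(handler: ModelCall, timings: list[float]) -> ModelCall:
    """Record the duration of every call, including ones that raise."""

    def wrapped(prompt: str) -> str:
        start = time.perf_counter()
        try:
            return handler(prompt)
        finally:
            timings.append(time.perf_counter() - start)

    return wrapped


model_invocations: list[str] = []


def fake_model(prompt: str) -> str:
    """Stand-in for the real LLM call."""
    model_invocations.append(prompt)
    return prompt.upper()


timings: list[float] = []
call_model = with_timing(with_cache(fake_model), timings)
first = call_model("who is employee 101?")
second = call_model("who is employee 101?")
print(len(model_invocations), len(timings))
```

Because timing is the outer wrapper, both calls are timed, while the cache in the inner wrapper ensures the model runs only once. Swapping the order would time only cache misses.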

---

## Lab 3.1: @wrap_model_call - Performance Tracking

In [None]:
performance_stats = {
    "calls": [],
    "total_time": 0.0
}

class wrap_model_call:
    """Decorator to wrap model calls."""
    
    def __init__(self, func: Callable):
        self.func = func
        wraps(func)(self)
    
    def __call__(self, model_func: Callable) -> Callable:
        """Wrap the model call function."""
        
        @wraps(model_func)
        def wrapper(*args, **kwargs):
            print(f"\n🔄 [@wrap_model_call] {self.func.__name__}")
            
            # BEFORE model call
            self.func("before", *args, **kwargs)
            
            # ACTUAL model call
            output = model_func(*args, **kwargs)
            
            # AFTER model call (the hook also receives the model output)
            self.func("after", output, *args, **kwargs)
            
            return output
        
        return wrapper

# Simpler hook-based approach for demo
def measure_model_performance_pre(state: AgentState) -> dict:
    """Start performance timer."""
    print(f"\n⏱️  [Performance] Starting model call...")
    return {"model_call_start": time.time()}

def measure_model_performance_post(state: AgentState) -> dict:
    """End performance timer."""
    start_time = state.get("model_call_start")
    if start_time:
        duration = time.time() - start_time
        performance_stats["calls"].append(duration)
        performance_stats["total_time"] += duration
        
        avg_time = performance_stats["total_time"] / len(performance_stats["calls"])
        
        print(f"\n⏱️  [Performance] Model call completed")
        print(f"   ⚡ This call: {duration:.3f}s")
        print(f"   📊 Average: {avg_time:.3f}s")
        print(f"   📈 Total calls: {len(performance_stats['calls'])}")
    
    return {}

print("✅ Performance tracking defined!")

## Lab 3.2: @wrap_tool_call - Tool Execution Monitoring

In [None]:
tool_execution_stats = {
    "total_calls": 0,
    "by_tool": {},
    "failures": []
}

def monitor_tool_execution(state: AgentState) -> dict:
    """Log tool calls requested by the latest model response."""
    messages = state.get("messages", [])
    if not messages:
        return {}
    
    # Inspect only the newest message; scanning the full history would
    # re-count earlier tool calls on every model call.
    for tool_call in getattr(messages[-1], 'tool_calls', None) or []:
        tool_name = tool_call.get('name', 'unknown')
        
        tool_execution_stats["total_calls"] += 1
        tool_execution_stats["by_tool"][tool_name] = tool_execution_stats["by_tool"].get(tool_name, 0) + 1
        
        print(f"\n🔧 [Tool Monitor] {tool_name}")
        print(f"   📞 Total tool calls: {tool_execution_stats['total_calls']}")
        print(f"   🎯 This tool: {tool_execution_stats['by_tool'][tool_name]} times")
        print(f"   📋 Args: {tool_call.get('args', {})}")
    
    return {}

print("✅ Tool monitoring defined!")
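
The monitor above observes tool requests after the model responds. A wrap-style handler goes further by surrounding the tool execution itself, so it can also record failures. The sketch below is framework-free and makes that assumption explicit: `monitored`, `lookup`, and the local `stats` dict are hypothetical and only mirror the shape of `tool_execution_stats`.

```python
from typing import Callable

stats = {"total_calls": 0, "by_tool": {}, "failures": []}


def monitored(name: str, tool_fn: Callable[..., str]) -> Callable[..., str]:
    """Wrap a tool function: count each call and capture any exception."""

    def wrapped(*args, **kwargs) -> str:
        stats["total_calls"] += 1
        stats["by_tool"][name] = stats["by_tool"].get(name, 0) + 1
        try:
            return tool_fn(*args, **kwargs)
        except Exception as exc:
            # Record the failure and return a message the agent can read
            stats["failures"].append({"tool": name, "error": str(exc)})
            return f"Tool {name} failed: {exc}"

    return wrapped


def lookup(employee_id: str) -> str:
    """Toy tool that raises KeyError for unknown IDs."""
    return {"101": "Priya Sharma"}[employee_id]


safe_lookup = monitored("lookup", lookup)
ok = safe_lookup("101")
failed = safe_lookup("999")
print(stats["total_calls"], len(stats["failures"]))
```

Returning an error string instead of re-raising lets the agent see the failure and recover; whether that is appropriate depends on the tool.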

---
# Part 4: @dynamic_prompt - Context-Aware Prompts

**What it does:**
- Generates dynamic system prompts
- Based on current state/context
- Equivalent to a `@wrap_model_call` hook that only rewrites the system prompt

**Use cases:**
- User-specific instructions
- Time-based context
- Role-based prompts
- Dynamic policies
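
At its core, a dynamic prompt is a function from the current context to a system prompt string. The sketch below assumes a small in-memory directory and passes the clock in as a parameter so the result is reproducible; `DIRECTORY`, `greeting_for`, and `build_system_prompt` are hypothetical names used for illustration only.

```python
from datetime import datetime

# Hypothetical sample data for illustration
DIRECTORY = {"101": {"name": "Priya Sharma", "role": "Senior Developer"}}


def greeting_for(hour: int) -> str:
    """Pick a greeting from the hour of day (0-23)."""
    if hour < 12:
        return "Good morning"
    if hour < 17:
        return "Good afternoon"
    return "Good evening"


def build_system_prompt(base: str, user_id: str, now: datetime) -> str:
    """Combine static instructions with time and user context."""
    lines = [base, f"{greeting_for(now.hour)}! Today is {now.strftime('%A')}."]
    user = DIRECTORY.get(user_id)
    if user:
        lines.append(f"Current user: {user['name']} ({user['role']}), ID {user_id}.")
    return "\n".join(lines)


prompt = build_system_prompt(
    "You are an HR assistant.", "101", datetime(2025, 1, 6, 9, 30)
)
anonymous_prompt = build_system_prompt(
    "You are an HR assistant.", "999", datetime(2025, 1, 6, 18, 0)
)
print(prompt)
```

Keeping the builder a pure function of its inputs makes it easy to test each branch (known user, unknown user, time of day) without invoking a model.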

---

## Lab 4.1: Dynamic Prompt Generation

In [None]:
def generate_dynamic_prompt(state: AgentState) -> dict:
    """Generate context-aware system prompt."""
    user_id = state.get("current_user_id", "unknown")
    session_time = datetime.now().strftime("%A, %I:%M %p")
    
    # Get user info
    user_context = ""
    if user_id in EMPLOYEES:
        emp = EMPLOYEES[user_id]
        user_context = f"""\nCURRENT USER CONTEXT:
- Name: {emp['name']}
- ID: {user_id}
- Department: {emp['department']}
- Role: {emp['role']}
"""
    
    # Time-based greeting
    hour = datetime.now().hour
    if hour < 12:
        greeting = "Good morning"
    elif hour < 17:
        greeting = "Good afternoon"
    else:
        greeting = "Good evening"
    
    dynamic_instructions = f"""\n{greeting}! Today is {session_time}.
{user_context}
When the user says 'my' or 'I', they are referring to employee {user_id}.
Be professional, helpful, and personalized.
"""
    
    print(f"\n💬 [Dynamic Prompt] Generated context for {user_id}")
    print(f"   ⏰ Time context: {greeting}, {session_time}")
    print(f"   👤 User context: {emp['name'] if user_id in EMPLOYEES else 'Unknown'}")
    
    # Store for use in prompt
    return {"dynamic_context": dynamic_instructions}

print("✅ Dynamic prompt generator defined!")

---
# Part 5: Complete Demo - All Decorators Together

Let's create an HR agent that uses ALL decorator patterns!

In [None]:
# Reset stats
model_call_stats = {"total_calls": 0, "call_history": []}
response_stats = {"total_tokens": 0, "total_cost": 0.0}
performance_stats = {"calls": [], "total_time": 0.0}
tool_execution_stats = {"total_calls": 0, "by_tool": {}, "failures": []}

# Comprehensive pre-model hook
def comprehensive_pre_hook(state: AgentState) -> dict:
    """All pre-processing decorators."""
    updates = {}
    
    # Session-level (only on first call)
    if not state.get("session_id"):
        result = combined_before_agent_hook(state)
        if result.get("jump_to"):
            return result
        updates.update(result)
    
    # Call-level (every call)
    updates.update(log_model_call(state))
    check_rate_limit_per_call(state)
    
    # Performance tracking
    updates.update(measure_model_performance_pre(state))
    
    # Dynamic prompt
    updates.update(generate_dynamic_prompt(state))
    
    return updates

# Comprehensive post-model hook
def comprehensive_post_hook(state: AgentState) -> dict:
    """All post-processing decorators."""
    # Call-level
    track_token_usage(state)
    validate_response_quality(state)
    
    # Performance
    measure_model_performance_post(state)
    
    # Tool monitoring
    monitor_tool_execution(state)
    
    # Session-level cleanup (only on last interaction)
    # In practice, this would be triggered by agent completion
    
    return {}

# Create comprehensive HR agent
comprehensive_hr_agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[get_employee_info, check_salary, update_employee_record],
    pre_model_hook=comprehensive_pre_hook,
    post_model_hook=comprehensive_post_hook,
    state_schema=SessionAgentState,
    checkpointer=InMemorySaver(),
    prompt="""You are a comprehensive HR assistant with full decorator monitoring.
    
    Help employees with their HR needs while maintaining:
    - Session tracking
    - Performance monitoring  
    - Cost tracking
    - Tool execution logging
    
    Be professional and use the dynamic context provided."""
)

print("✅ Comprehensive HR Agent created!")
print("\nActive decorators:")
print("  🚀 @before_agent: Session init, Authentication")
print("  🤖 @before_model: Logging, Rate limit, Performance, Dynamic prompt")
print("  ✅ @after_model: Token tracking, Validation")
print("  🔧 Tool monitoring: Execution tracking")
print("  🏁 @after_agent: Session summary, Cleanup")

## Test the Complete System

In [None]:
config = {"configurable": {"thread_id": "decorator_demo_1"}}

print("=" * 70)
print("COMPREHENSIVE DECORATOR DEMO")
print("=" * 70)

# Test 1: Basic query
print("\n" + "="*70)
print("Test 1: Employee Info Query")
print("="*70)

result = comprehensive_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Tell me about my role and department"}],
    "current_user_id": "101"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content[:150]}...")

# Test 2: Salary check (sensitive)
print("\n" + "="*70)
print("Test 2: Salary Check (Sensitive Operation)")
print("="*70)

result = comprehensive_hr_agent.invoke({
    "messages": [{"role": "user", "content": "What is my current salary?"}],
    "current_user_id": "101"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content}")

# Test 3: Another query to show multiple calls
print("\n" + "="*70)
print("Test 3: Department Info")
print("="*70)

result = comprehensive_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Who else works in my department?"}],
    "current_user_id": "101"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content[:150]}...")

# Session end - show final stats
print("\n" + "="*70)
print("SESSION END - FINAL STATISTICS")
print("="*70)

combined_after_agent_hook(result)

print(f"\n📊 Model Call Statistics:")
print(f"   Total calls: {model_call_stats['total_calls']}")

print(f"\n💰 Token & Cost Statistics:")
print(f"   Total tokens: ~{response_stats['total_tokens']}")
print(f"   Total cost: ${response_stats['total_cost']:.4f}")

print(f"\n⏱️  Performance Statistics:")
if performance_stats['calls']:
    print(f"   Total time: {performance_stats['total_time']:.3f}s")
    print(f"   Average: {performance_stats['total_time']/len(performance_stats['calls']):.3f}s per call")
    print(f"   Min: {min(performance_stats['calls']):.3f}s")
    print(f"   Max: {max(performance_stats['calls']):.3f}s")

print(f"\n🔧 Tool Execution Statistics:")
print(f"   Total tool calls: {tool_execution_stats['total_calls']}")
if tool_execution_stats['by_tool']:
    print(f"   Breakdown:")
    # Use tool_name so the loop does not shadow the imported @tool decorator
    for tool_name, count in tool_execution_stats['by_tool'].items():
        print(f"      • {tool_name}: {count} times")

print("\n✅ All decorators executed successfully!")

---
# Summary

## Decorator Types and Usage

| Decorator | Runs | Use For | Frequency |
|-----------|------|---------|----------|
| `@before_agent` | Start of invocation | Session init, auth | Once per invocation |
| `@before_model` | Before each LLM call | Logging, rate limits | Multiple per invocation |
| `@after_model` | After each LLM response | Token tracking, validation | Multiple per invocation |
| `@after_agent` | End of invocation | Cleanup, summary | Once per invocation |
| `@wrap_model_call` | Around LLM call | Performance, caching | Multiple per invocation |
| `@wrap_tool_call` | Around tool execution | Tool monitoring | Per tool call |
| `@dynamic_prompt` | Before LLM call | Context injection | Multiple per invocation |

## Best Practices

✅ **Session-Level Decorators:**
- Use @before_agent for one-time setup
- Use @after_agent for cleanup and summary
- Don't put per-call logic here

✅ **Call-Level Decorators:**
- Use @before_model for per-call checks
- Use @after_model for per-response processing
- Keep lightweight (runs frequently)

✅ **Wrap Decorators:**
- Use for cross-cutting concerns
- Good for performance measurement
- Can modify request/response

✅ **Dynamic Prompts:**
- Generate based on current state
- Include user context
- Add time-based instructions

## Production Considerations

- **Performance**: Decorators add overhead - keep them lightweight
- **Error Handling**: Wrap decorator logic in `try`/`except` so a failing hook does not crash the agent
- **Logging**: Use structured logging for better analysis
- **Metrics**: Export to monitoring systems
- **Testing**: Test decorators independently
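
As one possible way to apply the error-handling and logging points, a hook can be wrapped so that exceptions are logged as structured JSON rather than propagated. This is a sketch, not a recommended default for every hook: `safe_hook` and the logger name are assumptions introduced here.

```python
import json
import logging
from functools import wraps
from typing import Callable

logger = logging.getLogger("hr_agent.hooks")  # hypothetical logger name


def safe_hook(func: Callable[[dict], dict]) -> Callable[[dict], dict]:
    """Run a hook; on failure, log a structured record and return no updates."""

    @wraps(func)
    def wrapper(state: dict) -> dict:
        try:
            return func(state) or {}
        except Exception as exc:
            logger.error(json.dumps({"hook": func.__name__, "error": repr(exc)}))
            return {}

    return wrapper


@safe_hook
def flaky_hook(state: dict) -> dict:
    """Simulates a monitoring hook whose backend is down."""
    raise ValueError("metrics backend unavailable")


@safe_hook
def healthy_hook(state: dict) -> dict:
    return {"checked": True}


print(flaky_hook({}), healthy_hook({}))
```

Swallowing errors suits observability hooks such as logging or metrics. Security-relevant hooks such as authentication should fail closed instead, so they are better left unwrapped or wrapped with a handler that ends the run.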

---

**Next:** Explore class-based middleware patterns for more complex scenarios!